diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index e035677b8a..44e70fc4b8 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
+import hashlib
+import io
import logging
from typing import (
TYPE_CHECKING,
@@ -37,6 +39,7 @@ from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
+from synapse.handlers.device import DeviceHandler
from synapse.handlers.register import init_counters_for_auth_provider
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
@@ -137,6 +140,7 @@ class UserAttributes:
localpart: Optional[str]
confirm_localpart: bool = False
display_name: Optional[str] = None
+ picture: Optional[str] = None
emails: Collection[str] = attr.Factory(list)
@@ -191,9 +195,14 @@ class SsoHandler:
self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
self._error_template = hs.config.sso.sso_error_template
self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
self._profile_handler = hs.get_profile_handler()
+ self._media_repo = (
+ hs.get_media_repository() if hs.config.media.can_load_media_repo else None
+ )
+ self._http_client = hs.get_proxied_blacklisted_http_client()
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
@@ -493,6 +502,8 @@ class SsoHandler:
await self._profile_handler.set_displayname(
user_id_obj, requester, attributes.display_name, True
)
+ if attributes.picture:
+ await self.set_avatar(user_id, attributes.picture)
await self._auth_handler.complete_sso_login(
user_id,
@@ -701,8 +712,110 @@ class SsoHandler:
await self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
+
+ # Set avatar, if available
+ if attributes.picture:
+ await self.set_avatar(registered_user_id, attributes.picture)
+
return registered_user_id
+ async def set_avatar(self, user_id: str, picture_https_url: str) -> bool:
+ """Set avatar of the user.
+
+ This downloads the image file from the URL provided, stores that in
+ the media repository and then sets the avatar on the user's profile.
+
+ It can detect if the same image is being saved again and bails early by storing
+ the hash of the file in the `upload_name` of the avatar image.
+
+ Currently, it only supports server configurations which run the media repository
+ within the same process.
+
+ It silently fails and logs a warning by raising an exception and catching it
+ internally if:
+ * it is unable to fetch the image itself (non 200 status code) or
+ * the image supplied is bigger than max allowed size or
+ * the image type is not one of the allowed image types.
+
+ Args:
+ user_id: matrix user ID in the form @localpart:domain as a string.
+
+ picture_https_url: HTTPS url for the picture image file.
+
+ Returns: `True` if the user's avatar has been successfully set to the image at
+ `picture_https_url`.
+ """
+ if self._media_repo is None:
+ logger.info(
+ "failed to set user avatar because out-of-process media repositories "
+ "are not supported yet "
+ )
+ return False
+
+ try:
+ uid = UserID.from_string(user_id)
+
+ def is_allowed_mime_type(content_type: str) -> bool:
+ if (
+ self._profile_handler.allowed_avatar_mimetypes
+ and content_type
+ not in self._profile_handler.allowed_avatar_mimetypes
+ ):
+ return False
+ return True
+
+ # download picture, enforcing size limit & mime type check
+ picture = io.BytesIO()
+
+ content_length, headers, uri, code = await self._http_client.get_file(
+ url=picture_https_url,
+ output_stream=picture,
+ max_size=self._profile_handler.max_avatar_size,
+ is_allowed_content_type=is_allowed_mime_type,
+ )
+
+ if code != 200:
+ raise Exception(
+ "GET request to download sso avatar image returned {}".format(code)
+ )
+
+ # upload name includes hash of the image file's content so that we can
+ # easily check if it requires an update or not, the next time user logs in
+ upload_name = "sso_avatar_" + hashlib.sha256(picture.read()).hexdigest()
+
+ # bail if user already has the same avatar
+ profile = await self._profile_handler.get_profile(user_id)
+ if profile["avatar_url"] is not None:
+ server_name = profile["avatar_url"].split("/")[-2]
+ media_id = profile["avatar_url"].split("/")[-1]
+ if server_name == self._server_name:
+ media = await self._media_repo.store.get_local_media(media_id)
+ if media is not None and upload_name == media["upload_name"]:
+ logger.info("skipping saving the user avatar")
+ return True
+
+ # store it in media repository
+ avatar_mxc_url = await self._media_repo.create_content(
+ media_type=headers[b"Content-Type"][0].decode("utf-8"),
+ upload_name=upload_name,
+ content=picture,
+ content_length=content_length,
+ auth_user=uid,
+ )
+
+ # save it as user avatar
+ await self._profile_handler.set_avatar_url(
+ uid,
+ create_requester(uid),
+ str(avatar_mxc_url),
+ )
+
+ logger.info("successfully saved the user avatar")
+ return True
+ except Exception:
+ logger.warning("failed to save the user avatar")
+ return False
+
async def complete_sso_ui_auth_request(
self,
auth_provider_id: str,
@@ -874,7 +987,7 @@ class SsoHandler:
)
async def handle_terms_accepted(
- self, request: Request, session_id: str, terms_version: str
+ self, request: SynapseRequest, session_id: str, terms_version: str
) -> None:
"""Handle a request to the new-user 'consent' endpoint
@@ -1026,6 +1139,84 @@ class SsoHandler:
return True
+ async def revoke_sessions_for_provider_session_id(
+ self,
+ auth_provider_id: str,
+ auth_provider_session_id: str,
+ expected_user_id: Optional[str] = None,
+ ) -> None:
+ """Revoke any devices and in-flight logins tied to a provider session.
+
+ Can only be called from the main process.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ auth_provider_session_id: The session ID from the provider to logout
+ expected_user_id: The user we're expecting to logout. If set, it will ignore
+ sessions belonging to other users and log an error.
+ """
+
+ # It is expected that this is the main process.
+ assert isinstance(
+ self._device_handler, DeviceHandler
+ ), "revoking SSO sessions can only be called on the main process"
+
+ # Invalidate any running user-mapping sessions
+ to_delete = []
+ for session_id, session in self._username_mapping_sessions.items():
+ if (
+ session.auth_provider_id == auth_provider_id
+ and session.auth_provider_session_id == auth_provider_session_id
+ ):
+ to_delete.append(session_id)
+
+ for session_id in to_delete:
+ logger.info("Revoking mapping session %s", session_id)
+ del self._username_mapping_sessions[session_id]
+
+ # Invalidate any in-flight login tokens
+ await self._store.invalidate_login_tokens_by_session_id(
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
+ )
+
+ # Fetch any device(s) in the store associated with the session ID.
+ devices = await self._store.get_devices_by_auth_provider_session_id(
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
+ )
+
+ # We have no guarantee that all the devices of that session are for the same
+ # `user_id`. Hence, we have to iterate over the list of devices and log them out
+ # one by one.
+ for device in devices:
+ user_id = device["user_id"]
+ device_id = device["device_id"]
+
+ # If the user_id associated with that device/session is not the one we got
+ # out of the `sub` claim, skip that device and show log an error.
+ if expected_user_id is not None and user_id != expected_user_id:
+ logger.error(
+ "Received a logout notification from SSO provider "
+ f"{auth_provider_id!r} for the user {expected_user_id!r}, but with "
+ f"a session ID ({auth_provider_session_id!r}) which belongs to "
+ f"{user_id!r}. This may happen when the SSO provider user mapper "
+ "uses something else than the standard attribute as mapping ID. "
+ "For OIDC providers, set `backchannel_logout_ignore_sub` to `true` "
+ "in the provider config if that is the case."
+ )
+ continue
+
+ logger.info(
+ "Logging out %r (device %r) via SSO (%r) logout notification (session %r).",
+ user_id,
+ device_id,
+ auth_provider_id,
+ auth_provider_session_id,
+ )
+ await self._device_handler.delete_devices(user_id, [device_id])
+
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
"""Extract the session ID from the cookie
|