summary refs log tree commit diff
path: root/synapse/handlers/sso.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/sso.py')
-rw-r--r--synapse/handlers/sso.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 8e39e76c97..f275d4f35a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -169,6 +169,7 @@ class UsernameMappingSession:
     # attributes returned by the ID mapper
     display_name: Optional[str]
     emails: StrCollection
+    avatar_url: Optional[str]
 
     # An optional dictionary of extra attributes to be provided to the client in the
     # login response.
@@ -183,6 +184,7 @@ class UsernameMappingSession:
     # choices made by the user
     chosen_localpart: Optional[str] = None
     use_display_name: bool = True
+    use_avatar: bool = True
     emails_to_use: StrCollection = ()
     terms_accepted_version: Optional[str] = None
 
@@ -660,6 +662,9 @@ class SsoHandler:
             remote_user_id=remote_user_id,
             display_name=attributes.display_name,
             emails=attributes.emails,
+            avatar_url=attributes.picture,
+            # Default to using all mapped emails. Will be overwritten in handle_submit_username_request.
+            emails_to_use=attributes.emails,
             client_redirect_url=client_redirect_url,
             expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS,
             extra_login_attributes=extra_login_attributes,
@@ -966,6 +971,7 @@ class SsoHandler:
         session_id: str,
         localpart: str,
         use_display_name: bool,
+        use_avatar: bool,
         emails_to_use: Iterable[str],
     ) -> None:
         """Handle a request to the username-picker 'submit' endpoint
@@ -988,6 +994,7 @@ class SsoHandler:
         # update the session with the user's choices
         session.chosen_localpart = localpart
         session.use_display_name = use_display_name
+        session.use_avatar = use_avatar
 
         emails_from_idp = set(session.emails)
         filtered_emails: Set[str] = set()
@@ -1068,6 +1075,9 @@ class SsoHandler:
         if session.use_display_name:
             attributes.display_name = session.display_name
 
+        if session.use_avatar:
+            attributes.picture = session.avatar_url
+
         # the following will raise a 400 error if the username has been taken in the
         # meantime.
         user_id = await self._register_mapped_user(