summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17120.bugfix1
-rw-r--r--docs/sso_mapping_providers.md1
-rw-r--r--synapse/handlers/sso.py10
-rw-r--r--synapse/rest/synapse/client/pick_username.py4
-rw-r--r--tests/rest/client/test_login.py204
5 files changed, 205 insertions, 15 deletions
diff --git a/changelog.d/17120.bugfix b/changelog.d/17120.bugfix
new file mode 100644
index 0000000000..85b34c2e98
--- /dev/null
+++ b/changelog.d/17120.bugfix
@@ -0,0 +1 @@
+Apply user email & picture during OIDC registration if present & selected.
diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md
index 10c695029f..d6c4e860ae 100644
--- a/docs/sso_mapping_providers.md
+++ b/docs/sso_mapping_providers.md
@@ -98,6 +98,7 @@ A custom mapping provider must specify the following methods:
         either accept this localpart or pick their own username. Otherwise this
         option has no effect. If omitted, defaults to `False`.
       - `display_name`: An optional string, the display name for the user.
+      - `picture`: An optional string, the avatar url for the user.
       - `emails`: A list of strings, the email address(es) to associate with
         this user. If omitted, defaults to an empty list.
 * `async def get_extra_attributes(self, userinfo, token)`
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 8e39e76c97..f275d4f35a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -169,6 +169,7 @@ class UsernameMappingSession:
     # attributes returned by the ID mapper
     display_name: Optional[str]
     emails: StrCollection
+    avatar_url: Optional[str]
 
     # An optional dictionary of extra attributes to be provided to the client in the
     # login response.
@@ -183,6 +184,7 @@ class UsernameMappingSession:
     # choices made by the user
     chosen_localpart: Optional[str] = None
     use_display_name: bool = True
+    use_avatar: bool = True
     emails_to_use: StrCollection = ()
     terms_accepted_version: Optional[str] = None
 
@@ -660,6 +662,9 @@ class SsoHandler:
             remote_user_id=remote_user_id,
             display_name=attributes.display_name,
             emails=attributes.emails,
+            avatar_url=attributes.picture,
+            # Default to using all mapped emails. Will be overwritten in handle_submit_username_request.
+            emails_to_use=attributes.emails,
             client_redirect_url=client_redirect_url,
             expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS,
             extra_login_attributes=extra_login_attributes,
@@ -966,6 +971,7 @@ class SsoHandler:
         session_id: str,
         localpart: str,
         use_display_name: bool,
+        use_avatar: bool,
         emails_to_use: Iterable[str],
     ) -> None:
         """Handle a request to the username-picker 'submit' endpoint
@@ -988,6 +994,7 @@ class SsoHandler:
         # update the session with the user's choices
         session.chosen_localpart = localpart
         session.use_display_name = use_display_name
+        session.use_avatar = use_avatar
 
         emails_from_idp = set(session.emails)
         filtered_emails: Set[str] = set()
@@ -1068,6 +1075,9 @@ class SsoHandler:
         if session.use_display_name:
             attributes.display_name = session.display_name
 
+        if session.use_avatar:
+            attributes.picture = session.avatar_url
+
         # the following will raise a 400 error if the username has been taken in the
         # meantime.
         user_id = await self._register_mapped_user(
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
index e671774aeb..7d16b796d4 100644
--- a/synapse/rest/synapse/client/pick_username.py
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -113,6 +113,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
                 "display_name": session.display_name,
                 "emails": session.emails,
                 "localpart": localpart,
+                "avatar_url": session.avatar_url,
             },
         }
 
@@ -134,6 +135,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
         try:
             localpart = parse_string(request, "username", required=True)
             use_display_name = parse_boolean(request, "use_display_name", default=False)
+            use_avatar = parse_boolean(request, "use_avatar", default=False)
 
             try:
                 emails_to_use: List[str] = [
@@ -147,5 +149,5 @@ class AccountDetailsResource(DirectServeHtmlResource):
             return
 
         await self._sso_handler.handle_submit_username_request(
-            request, session_id, localpart, use_display_name, emails_to_use
+            request, session_id, localpart, use_display_name, use_avatar, emails_to_use
         )
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 3a1f150082..3fb77fd9dd 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -20,7 +20,17 @@
 #
 import time
 import urllib.parse
-from typing import Any, Collection, Dict, List, Optional, Tuple, Union
+from typing import (
+    Any,
+    BinaryIO,
+    Callable,
+    Collection,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 from unittest.mock import Mock
 from urllib.parse import urlencode
 
@@ -34,8 +44,9 @@ import synapse.rest.admin
 from synapse.api.constants import ApprovalNoticeMedium, LoginType
 from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
+from synapse.http.client import RawHeaders
 from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login, logout, register
+from synapse.rest.client import account, devices, login, logout, profile, register
 from synapse.rest.client.account import WhoamiRestServlet
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.server import HomeServer
@@ -48,6 +59,7 @@ from tests.handlers.test_saml import has_saml2
 from tests.rest.client.utils import TEST_OIDC_CONFIG
 from tests.server import FakeChannel
 from tests.test_utils.html_parsers import TestHtmlParser
+from tests.test_utils.oidc import FakeOidcServer
 from tests.unittest import HomeserverTestCase, override_config, skip_unless
 
 try:
@@ -1421,7 +1433,19 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
 class UsernamePickerTestCase(HomeserverTestCase):
     """Tests for the username picker flow of SSO login"""
 
-    servlets = [login.register_servlets]
+    servlets = [
+        login.register_servlets,
+        profile.register_servlets,
+        account.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        self.http_client = Mock(spec=["get_file"])
+        self.http_client.get_file.side_effect = mock_get_file
+        hs = self.setup_test_homeserver(
+            proxied_blocklisted_http_client=self.http_client
+        )
+        return hs
 
     def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
@@ -1430,7 +1454,11 @@ class UsernamePickerTestCase(HomeserverTestCase):
         config["oidc_config"] = {}
         config["oidc_config"].update(TEST_OIDC_CONFIG)
         config["oidc_config"]["user_mapping_provider"] = {
-            "config": {"display_name_template": "{{ user.displayname }}"}
+            "config": {
+                "display_name_template": "{{ user.displayname }}",
+                "email_template": "{{ user.email }}",
+                "picture_template": "{{ user.picture }}",
+            }
         }
 
         # whitelist this client URI so we redirect straight to it rather than
@@ -1443,15 +1471,22 @@ class UsernamePickerTestCase(HomeserverTestCase):
         d.update(build_synapse_client_resource_tree(self.hs))
         return d
 
-    def test_username_picker(self) -> None:
-        """Test the happy path of a username picker flow."""
-
-        fake_oidc_server = self.helper.fake_oidc_server()
-
+    def proceed_to_username_picker_page(
+        self,
+        fake_oidc_server: FakeOidcServer,
+        displayname: str,
+        email: str,
+        picture: str,
+    ) -> Tuple[str, str]:
         # do the start of the login flow
         channel, _ = self.helper.auth_via_oidc(
             fake_oidc_server,
-            {"sub": "tester", "displayname": "Jonny"},
+            {
+                "sub": "tester",
+                "displayname": displayname,
+                "picture": picture,
+                "email": email,
+            },
             TEST_CLIENT_REDIRECT_URL,
         )
 
@@ -1478,16 +1513,132 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
         session = username_mapping_sessions[session_id]
         self.assertEqual(session.remote_user_id, "tester")
-        self.assertEqual(session.display_name, "Jonny")
+        self.assertEqual(session.display_name, displayname)
+        self.assertEqual(session.emails, [email])
+        self.assertEqual(session.avatar_url, picture)
         self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
 
         # the expiry time should be about 15 minutes away
         expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
         self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
 
+        return picker_url, session_id
+
+    def test_username_picker_use_displayname_avatar_and_email(self) -> None:
+        """Test the happy path of a username picker flow with using displayname, avatar and email."""
+
+        fake_oidc_server = self.helper.fake_oidc_server()
+
+        mxid = "@bobby:test"
+        displayname = "Jonny"
+        email = "bobby@test.com"
+        picture = "mxc://test/avatar_url"
+
+        picker_url, session_id = self.proceed_to_username_picker_page(
+            fake_oidc_server, displayname, email, picture
+        )
+
+        # Now, submit a username to the username picker, which should serve a redirect
+        # to the completion page.
+        # Also specify that we should use the provided displayname, avatar and email.
+        content = urlencode(
+            {
+                b"username": b"bobby",
+                b"use_display_name": b"true",
+                b"use_avatar": b"true",
+                b"use_email": email,
+            }
+        ).encode("utf8")
+        chan = self.make_request(
+            "POST",
+            path=picker_url,
+            content=content,
+            content_is_form=True,
+            custom_headers=[
+                ("Cookie", "username_mapping_session=" + session_id),
+                # old versions of twisted don't do form-parsing without a valid
+                # content-length header.
+                ("Content-Length", str(len(content))),
+            ],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+        assert location_headers
+
+        # send a request to the completion page, which should 302 to the client redirectUrl
+        chan = self.make_request(
+            "GET",
+            path=location_headers[0],
+            custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+        assert location_headers
+
+        # ensure that the returned location matches the requested redirect URL
+        path, query = location_headers[0].split("?", 1)
+        self.assertEqual(path, "https://x")
+
+        # it will have url-encoded the params properly, so we'll have to parse them
+        params = urllib.parse.parse_qsl(
+            query, keep_blank_values=True, strict_parsing=True, errors="strict"
+        )
+        self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+        self.assertEqual(params[2][0], "loginToken")
+
+        # fish the login token out of the returned redirect uri
+        login_token = params[2][1]
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token, mxid, and device id.
+        chan = self.make_request(
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.json_body["user_id"], mxid)
+
+        # ensure the displayname and avatar from the OIDC response have been configured for the user.
+        channel = self.make_request(
+            "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"]
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertIn("mxc://test", channel.json_body["avatar_url"])
+        self.assertEqual(displayname, channel.json_body["displayname"])
+
+        # ensure the email from the OIDC response has been configured for the user.
+        channel = self.make_request(
+            "GET", "/account/3pid", access_token=chan.json_body["access_token"]
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(email, channel.json_body["threepids"][0]["address"])
+
+    def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None:
+        """Test the happy path of a username picker flow without using displayname, avatar or email."""
+
+        fake_oidc_server = self.helper.fake_oidc_server()
+
+        mxid = "@bobby:test"
+        displayname = "Jonny"
+        email = "bobby@test.com"
+        picture = "mxc://test/avatar_url"
+        username = "bobby"
+
+        picker_url, session_id = self.proceed_to_username_picker_page(
+            fake_oidc_server, displayname, email, picture
+        )
+
         # Now, submit a username to the username picker, which should serve a redirect
-        # to the completion page
-        content = urlencode({b"username": b"bobby"}).encode("utf8")
+        # to the completion page.
+        # Also specify that we should not use the provided displayname, avatar or email.
+        content = urlencode(
+            {
+                b"username": username,
+                b"use_display_name": b"false",
+                b"use_avatar": b"false",
+            }
+        ).encode("utf8")
         chan = self.make_request(
             "POST",
             path=picker_url,
@@ -1536,4 +1687,29 @@ class UsernamePickerTestCase(HomeserverTestCase):
             content={"type": "m.login.token", "token": login_token},
         )
         self.assertEqual(chan.code, 200, chan.result)
-        self.assertEqual(chan.json_body["user_id"], "@bobby:test")
+        self.assertEqual(chan.json_body["user_id"], mxid)
+
+        # ensure the displayname and avatar from the OIDC response have not been configured for the user.
+        channel = self.make_request(
+            "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"]
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertNotIn("avatar_url", channel.json_body)
+        self.assertEqual(username, channel.json_body["displayname"])
+
+        # ensure the email from the OIDC response has not been configured for the user.
+        channel = self.make_request(
+            "GET", "/account/3pid", access_token=chan.json_body["access_token"]
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertListEqual([], channel.json_body["threepids"])
+
+
+async def mock_get_file(
+    url: str,
+    output_stream: BinaryIO,
+    max_size: Optional[int] = None,
+    headers: Optional[RawHeaders] = None,
+    is_allowed_content_type: Optional[Callable[[str], bool]] = None,
+) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
+    return 0, {b"Content-Type": [b"image/png"]}, "", 200