summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/test_login.py204
1 files changed, 190 insertions, 14 deletions
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 3a1f150082..3fb77fd9dd 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -20,7 +20,17 @@
 #
 import time
 import urllib.parse
-from typing import Any, Collection, Dict, List, Optional, Tuple, Union
+from typing import (
+    Any,
+    BinaryIO,
+    Callable,
+    Collection,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 from unittest.mock import Mock
 from urllib.parse import urlencode
 
@@ -34,8 +44,9 @@ import synapse.rest.admin
 from synapse.api.constants import ApprovalNoticeMedium, LoginType
 from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
+from synapse.http.client import RawHeaders
 from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login, logout, register
+from synapse.rest.client import account, devices, login, logout, profile, register
 from synapse.rest.client.account import WhoamiRestServlet
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.server import HomeServer
@@ -48,6 +59,7 @@ from tests.handlers.test_saml import has_saml2
 from tests.rest.client.utils import TEST_OIDC_CONFIG
 from tests.server import FakeChannel
 from tests.test_utils.html_parsers import TestHtmlParser
+from tests.test_utils.oidc import FakeOidcServer
 from tests.unittest import HomeserverTestCase, override_config, skip_unless
 
 try:
@@ -1421,7 +1433,19 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
 class UsernamePickerTestCase(HomeserverTestCase):
     """Tests for the username picker flow of SSO login"""
 
-    servlets = [login.register_servlets]
+    servlets = [
+        login.register_servlets,
+        profile.register_servlets,
+        account.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        self.http_client = Mock(spec=["get_file"])
+        self.http_client.get_file.side_effect = mock_get_file
+        hs = self.setup_test_homeserver(
+            proxied_blocklisted_http_client=self.http_client
+        )
+        return hs
 
     def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
@@ -1430,7 +1454,11 @@ class UsernamePickerTestCase(HomeserverTestCase):
         config["oidc_config"] = {}
         config["oidc_config"].update(TEST_OIDC_CONFIG)
         config["oidc_config"]["user_mapping_provider"] = {
-            "config": {"display_name_template": "{{ user.displayname }}"}
+            "config": {
+                "display_name_template": "{{ user.displayname }}",
+                "email_template": "{{ user.email }}",
+                "picture_template": "{{ user.picture }}",
+            }
         }
 
         # whitelist this client URI so we redirect straight to it rather than
@@ -1443,15 +1471,22 @@ class UsernamePickerTestCase(HomeserverTestCase):
         d.update(build_synapse_client_resource_tree(self.hs))
         return d
 
-    def test_username_picker(self) -> None:
-        """Test the happy path of a username picker flow."""
-
-        fake_oidc_server = self.helper.fake_oidc_server()
-
+    def proceed_to_username_picker_page(
+        self,
+        fake_oidc_server: FakeOidcServer,
+        displayname: str,
+        email: str,
+        picture: str,
+    ) -> Tuple[str, str]:
         # do the start of the login flow
         channel, _ = self.helper.auth_via_oidc(
             fake_oidc_server,
-            {"sub": "tester", "displayname": "Jonny"},
+            {
+                "sub": "tester",
+                "displayname": displayname,
+                "picture": picture,
+                "email": email,
+            },
             TEST_CLIENT_REDIRECT_URL,
         )
 
@@ -1478,16 +1513,132 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
         session = username_mapping_sessions[session_id]
         self.assertEqual(session.remote_user_id, "tester")
-        self.assertEqual(session.display_name, "Jonny")
+        self.assertEqual(session.display_name, displayname)
+        self.assertEqual(session.emails, [email])
+        self.assertEqual(session.avatar_url, picture)
         self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
 
         # the expiry time should be about 15 minutes away
         expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
         self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
 
+        return picker_url, session_id
+
+    def test_username_picker_use_displayname_avatar_and_email(self) -> None:
+        """Test the happy path of a username picker flow with using displayname, avatar and email."""
+
+        fake_oidc_server = self.helper.fake_oidc_server()
+
+        mxid = "@bobby:test"
+        displayname = "Jonny"
+        email = "bobby@test.com"
+        picture = "mxc://test/avatar_url"
+
+        picker_url, session_id = self.proceed_to_username_picker_page(
+            fake_oidc_server, displayname, email, picture
+        )
+
+        # Now, submit a username to the username picker, which should serve a redirect
+        # to the completion page.
+        # Also specify that we should use the provided displayname, avatar and email.
+        content = urlencode(
+            {
+                b"username": b"bobby",
+                b"use_display_name": b"true",
+                b"use_avatar": b"true",
+                b"use_email": email,
+            }
+        ).encode("utf8")
+        chan = self.make_request(
+            "POST",
+            path=picker_url,
+            content=content,
+            content_is_form=True,
+            custom_headers=[
+                ("Cookie", "username_mapping_session=" + session_id),
+                # old versions of twisted don't do form-parsing without a valid
+                # content-length header.
+                ("Content-Length", str(len(content))),
+            ],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+        assert location_headers
+
+        # send a request to the completion page, which should 302 to the client redirectUrl
+        chan = self.make_request(
+            "GET",
+            path=location_headers[0],
+            custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+        assert location_headers
+
+        # ensure that the returned location matches the requested redirect URL
+        path, query = location_headers[0].split("?", 1)
+        self.assertEqual(path, "https://x")
+
+        # it will have url-encoded the params properly, so we'll have to parse them
+        params = urllib.parse.parse_qsl(
+            query, keep_blank_values=True, strict_parsing=True, errors="strict"
+        )
+        self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+        self.assertEqual(params[2][0], "loginToken")
+
+        # fish the login token out of the returned redirect uri
+        login_token = params[2][1]
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token, mxid, and device id.
+        chan = self.make_request(
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.json_body["user_id"], mxid)
+
+        # ensure the displayname and avatar from the OIDC response have been configured for the user.
+        channel = self.make_request(
+            "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"]
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertIn("mxc://test", channel.json_body["avatar_url"])
+        self.assertEqual(displayname, channel.json_body["displayname"])
+
+        # ensure the email from the OIDC response has been configured for the user.
+        channel = self.make_request(
+            "GET", "/account/3pid", access_token=chan.json_body["access_token"]
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(email, channel.json_body["threepids"][0]["address"])
+
+    def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None:
+        """Test the happy path of a username picker flow without using displayname, avatar or email."""
+
+        fake_oidc_server = self.helper.fake_oidc_server()
+
+        mxid = "@bobby:test"
+        displayname = "Jonny"
+        email = "bobby@test.com"
+        picture = "mxc://test/avatar_url"
+        username = "bobby"
+
+        picker_url, session_id = self.proceed_to_username_picker_page(
+            fake_oidc_server, displayname, email, picture
+        )
+
         # Now, submit a username to the username picker, which should serve a redirect
-        # to the completion page
-        content = urlencode({b"username": b"bobby"}).encode("utf8")
+        # to the completion page.
+        # Also specify that we should not use the provided displayname, avatar or email.
+        content = urlencode(
+            {
+                b"username": username,
+                b"use_display_name": b"false",
+                b"use_avatar": b"false",
+            }
+        ).encode("utf8")
         chan = self.make_request(
             "POST",
             path=picker_url,
@@ -1536,4 +1687,29 @@ class UsernamePickerTestCase(HomeserverTestCase):
             content={"type": "m.login.token", "token": login_token},
         )
         self.assertEqual(chan.code, 200, chan.result)
-        self.assertEqual(chan.json_body["user_id"], "@bobby:test")
+        self.assertEqual(chan.json_body["user_id"], mxid)
+
+        # ensure the displayname and avatar from the OIDC response have not been configured for the user.
+        channel = self.make_request(
+            "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"]
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertNotIn("avatar_url", channel.json_body)
+        self.assertEqual(username, channel.json_body["displayname"])
+
+        # ensure the email from the OIDC response has not been configured for the user.
+        channel = self.make_request(
+            "GET", "/account/3pid", access_token=chan.json_body["access_token"]
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertListEqual([], channel.json_body["threepids"])
+
+
+async def mock_get_file(
+    url: str,
+    output_stream: BinaryIO,
+    max_size: Optional[int] = None,
+    headers: Optional[RawHeaders] = None,
+    is_allowed_content_type: Optional[Callable[[str], bool]] = None,
+) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
+    return 0, {b"Content-Type": [b"image/png"]}, "", 200