diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 3a1f150082..3fb77fd9dd 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -20,7 +20,17 @@
#
import time
import urllib.parse
-from typing import Any, Collection, Dict, List, Optional, Tuple, Union
+from typing import (
+ Any,
+ BinaryIO,
+ Callable,
+ Collection,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
from unittest.mock import Mock
from urllib.parse import urlencode
@@ -34,8 +44,9 @@ import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
+from synapse.http.client import RawHeaders
from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login, logout, register
+from synapse.rest.client import account, devices, login, logout, profile, register
from synapse.rest.client.account import WhoamiRestServlet
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer
@@ -48,6 +59,7 @@ from tests.handlers.test_saml import has_saml2
from tests.rest.client.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
from tests.test_utils.html_parsers import TestHtmlParser
+from tests.test_utils.oidc import FakeOidcServer
from tests.unittest import HomeserverTestCase, override_config, skip_unless
try:
@@ -1421,7 +1433,19 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
class UsernamePickerTestCase(HomeserverTestCase):
"""Tests for the username picker flow of SSO login"""
- servlets = [login.register_servlets]
+ servlets = [
+ login.register_servlets,
+ profile.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.http_client = Mock(spec=["get_file"])
+ self.http_client.get_file.side_effect = mock_get_file
+ hs = self.setup_test_homeserver(
+ proxied_blocklisted_http_client=self.http_client
+ )
+ return hs
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
@@ -1430,7 +1454,11 @@ class UsernamePickerTestCase(HomeserverTestCase):
config["oidc_config"] = {}
config["oidc_config"].update(TEST_OIDC_CONFIG)
config["oidc_config"]["user_mapping_provider"] = {
- "config": {"display_name_template": "{{ user.displayname }}"}
+ "config": {
+ "display_name_template": "{{ user.displayname }}",
+ "email_template": "{{ user.email }}",
+ "picture_template": "{{ user.picture }}",
+ }
}
# whitelist this client URI so we redirect straight to it rather than
@@ -1443,15 +1471,22 @@ class UsernamePickerTestCase(HomeserverTestCase):
d.update(build_synapse_client_resource_tree(self.hs))
return d
- def test_username_picker(self) -> None:
- """Test the happy path of a username picker flow."""
-
- fake_oidc_server = self.helper.fake_oidc_server()
-
+ def proceed_to_username_picker_page(
+ self,
+ fake_oidc_server: FakeOidcServer,
+ displayname: str,
+ email: str,
+ picture: str,
+ ) -> Tuple[str, str]:
# do the start of the login flow
channel, _ = self.helper.auth_via_oidc(
fake_oidc_server,
- {"sub": "tester", "displayname": "Jonny"},
+ {
+ "sub": "tester",
+ "displayname": displayname,
+ "picture": picture,
+ "email": email,
+ },
TEST_CLIENT_REDIRECT_URL,
)
@@ -1478,16 +1513,132 @@ class UsernamePickerTestCase(HomeserverTestCase):
)
session = username_mapping_sessions[session_id]
self.assertEqual(session.remote_user_id, "tester")
- self.assertEqual(session.display_name, "Jonny")
+ self.assertEqual(session.display_name, displayname)
+ self.assertEqual(session.emails, [email])
+ self.assertEqual(session.avatar_url, picture)
self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
# the expiry time should be about 15 minutes away
expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
+ return picker_url, session_id
+
+ def test_username_picker_use_displayname_avatar_and_email(self) -> None:
+ """Test the happy path of a username picker flow with using displayname, avatar and email."""
+
+ fake_oidc_server = self.helper.fake_oidc_server()
+
+ mxid = "@bobby:test"
+ displayname = "Jonny"
+ email = "bobby@test.com"
+ picture = "mxc://test/avatar_url"
+
+ picker_url, session_id = self.proceed_to_username_picker_page(
+ fake_oidc_server, displayname, email, picture
+ )
+
+ # Now, submit a username to the username picker, which should serve a redirect
+ # to the completion page.
+ # Also specify that we should use the provided displayname, avatar and email.
+ content = urlencode(
+ {
+ b"username": b"bobby",
+ b"use_display_name": b"true",
+ b"use_avatar": b"true",
+ b"use_email": email,
+ }
+ ).encode("utf8")
+ chan = self.make_request(
+ "POST",
+ path=picker_url,
+ content=content,
+ content_is_form=True,
+ custom_headers=[
+ ("Cookie", "username_mapping_session=" + session_id),
+ # old versions of twisted don't do form-parsing without a valid
+ # content-length header.
+ ("Content-Length", str(len(content))),
+ ],
+ )
+ self.assertEqual(chan.code, 302, chan.result)
+ location_headers = chan.headers.getRawHeaders("Location")
+ assert location_headers
+
+ # send a request to the completion page, which should 302 to the client redirectUrl
+ chan = self.make_request(
+ "GET",
+ path=location_headers[0],
+ custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
+ )
+ self.assertEqual(chan.code, 302, chan.result)
+ location_headers = chan.headers.getRawHeaders("Location")
+ assert location_headers
+
+ # ensure that the returned location matches the requested redirect URL
+ path, query = location_headers[0].split("?", 1)
+ self.assertEqual(path, "https://x")
+
+ # it will have url-encoded the params properly, so we'll have to parse them
+ params = urllib.parse.parse_qsl(
+ query, keep_blank_values=True, strict_parsing=True, errors="strict"
+ )
+ self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+ self.assertEqual(params[2][0], "loginToken")
+
+ # fish the login token out of the returned redirect uri
+ login_token = params[2][1]
+
+ # finally, submit the matrix login token to the login API, which gives us our
+ # matrix access token, mxid, and device id.
+ chan = self.make_request(
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+ self.assertEqual(chan.json_body["user_id"], mxid)
+
+ # ensure the displayname and avatar from the OIDC response have been configured for the user.
+ channel = self.make_request(
+ "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"]
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertIn("mxc://test", channel.json_body["avatar_url"])
+ self.assertEqual(displayname, channel.json_body["displayname"])
+
+ # ensure the email from the OIDC response has been configured for the user.
+ channel = self.make_request(
+ "GET", "/account/3pid", access_token=chan.json_body["access_token"]
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(email, channel.json_body["threepids"][0]["address"])
+
+ def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None:
+ """Test the happy path of a username picker flow without using displayname, avatar or email."""
+
+ fake_oidc_server = self.helper.fake_oidc_server()
+
+ mxid = "@bobby:test"
+ displayname = "Jonny"
+ email = "bobby@test.com"
+ picture = "mxc://test/avatar_url"
+ username = "bobby"
+
+ picker_url, session_id = self.proceed_to_username_picker_page(
+ fake_oidc_server, displayname, email, picture
+ )
+
# Now, submit a username to the username picker, which should serve a redirect
- # to the completion page
- content = urlencode({b"username": b"bobby"}).encode("utf8")
+ # to the completion page.
+ # Also specify that we should not use the provided displayname, avatar or email.
+ content = urlencode(
+ {
+ b"username": username,
+ b"use_display_name": b"false",
+ b"use_avatar": b"false",
+ }
+ ).encode("utf8")
chan = self.make_request(
"POST",
path=picker_url,
@@ -1536,4 +1687,29 @@ class UsernamePickerTestCase(HomeserverTestCase):
content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, 200, chan.result)
- self.assertEqual(chan.json_body["user_id"], "@bobby:test")
+ self.assertEqual(chan.json_body["user_id"], mxid)
+
+ # ensure the displayname and avatar from the OIDC response have not been configured for the user.
+ channel = self.make_request(
+ "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"]
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertNotIn("avatar_url", channel.json_body)
+ self.assertEqual(username, channel.json_body["displayname"])
+
+ # ensure the email from the OIDC response has not been configured for the user.
+ channel = self.make_request(
+ "GET", "/account/3pid", access_token=chan.json_body["access_token"]
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertListEqual([], channel.json_body["threepids"])
+
+
+async def mock_get_file(
+ url: str,
+ output_stream: BinaryIO,
+ max_size: Optional[int] = None,
+ headers: Optional[RawHeaders] = None,
+ is_allowed_content_type: Optional[Callable[[str], bool]] = None,
+) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
+ return 0, {b"Content-Type": [b"image/png"]}, "", 200
|