diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 4740dd0a65..49d832de81 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -84,7 +84,7 @@ class CustomAuthProvider:
def __init__(self, config, api: ModuleApi):
api.register_password_auth_provider_callbacks(
- auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
+ auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
)
def check_auth(self, *args):
@@ -122,7 +122,7 @@ class PasswordCustomAuthProvider:
auth_checkers={
("test.login_type", ("test_field",)): self.check_auth,
("m.login.password", ("password",)): self.check_auth,
- },
+ }
)
pass
@@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
account.register_servlets,
]
+ CALLBACK_USERNAME = "get_username_for_registration"
+ CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
+
def setUp(self):
# we use a global mock device, so make sure we are starting with a clean slate
mock_password_provider.reset_mock()
@@ -754,7 +757,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"""Tests that the get_username_for_registration callback can define the username
of a user when registering.
"""
- self._setup_get_username_for_registration()
+ self._setup_get_name_for_registration(
+ callback_name=self.CALLBACK_USERNAME,
+ )
username = "rin"
channel = self.make_request(
@@ -777,30 +782,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"""Tests that the get_username_for_registration callback is only called at the
end of the UIA flow.
"""
- m = self._setup_get_username_for_registration()
-
- # Initiate the UIA flow.
- username = "rin"
- channel = self.make_request(
- "POST",
- "register",
- {"username": username, "type": "m.login.password", "password": "bar"},
+ m = self._setup_get_name_for_registration(
+ callback_name=self.CALLBACK_USERNAME,
)
- self.assertEqual(channel.code, 401)
- self.assertIn("session", channel.json_body)
- # Check that the callback hasn't been called yet.
- m.assert_not_called()
+ username = "rin"
+ res = self._do_uia_assert_mock_not_called(username, m)
- # Finish the UIA flow.
- session = channel.json_body["session"]
- channel = self.make_request(
- "POST",
- "register",
- {"auth": {"session": session, "type": LoginType.DUMMY}},
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- mxid = channel.json_body["user_id"]
+ mxid = res["user_id"]
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
# Check that the callback has been called.
@@ -817,6 +806,56 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self._test_3pid_allowed("rin", False)
self._test_3pid_allowed("kitay", True)
+ def test_displayname(self):
+ """Tests that the get_displayname_for_registration callback can define the
+ display name of a user when registering.
+ """
+ self._setup_get_name_for_registration(
+ callback_name=self.CALLBACK_DISPLAYNAME,
+ )
+
+ username = "rin"
+ channel = self.make_request(
+ "POST",
+ "/register",
+ {
+ "username": username,
+ "password": "bar",
+ "auth": {"type": LoginType.DUMMY},
+ },
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Our callback takes the username and appends "-foo" to it, check that's what we
+ # have.
+ user_id = UserID.from_string(channel.json_body["user_id"])
+ display_name = self.get_success(
+ self.hs.get_profile_handler().get_displayname(user_id)
+ )
+
+ self.assertEqual(display_name, username + "-foo")
+
+ def test_displayname_uia(self):
+ """Tests that the get_displayname_for_registration callback is only called at the
+ end of the UIA flow.
+ """
+ m = self._setup_get_name_for_registration(
+ callback_name=self.CALLBACK_DISPLAYNAME,
+ )
+
+ username = "rin"
+ res = self._do_uia_assert_mock_not_called(username, m)
+
+ user_id = UserID.from_string(res["user_id"])
+ display_name = self.get_success(
+ self.hs.get_profile_handler().get_displayname(user_id)
+ )
+
+ self.assertEqual(display_name, username + "-foo")
+
+ # Check that the callback has been called.
+ m.assert_called_once()
+
def _test_3pid_allowed(self, username: str, registration: bool):
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
either /register or /account URLs depending on the arguments.
@@ -877,23 +916,47 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
m.assert_called_once_with("email", "bar@test.com", registration)
- def _setup_get_username_for_registration(self) -> Mock:
- """Registers a get_username_for_registration callback that appends "-foo" to the
- username the client is trying to register.
+ def _setup_get_name_for_registration(self, callback_name: str) -> Mock:
+ """Registers either a get_username_for_registration callback or a
+ get_displayname_for_registration callback that appends "-foo" to the username the
+ client is trying to register.
"""
- async def get_username_for_registration(uia_results, params):
+ async def callback(uia_results, params):
self.assertIn(LoginType.DUMMY, uia_results)
username = params["username"]
return username + "-foo"
- m = Mock(side_effect=get_username_for_registration)
+ m = Mock(side_effect=callback)
password_auth_provider = self.hs.get_password_auth_provider()
- password_auth_provider.get_username_for_registration_callbacks.append(m)
+ getattr(password_auth_provider, callback_name + "_callbacks").append(m)
return m
+ def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict:
+ # Initiate the UIA flow.
+ channel = self.make_request(
+ "POST",
+ "register",
+ {"username": username, "type": "m.login.password", "password": "bar"},
+ )
+ self.assertEqual(channel.code, 401)
+ self.assertIn("session", channel.json_body)
+
+ # Check that the callback hasn't been called yet.
+ m.assert_not_called()
+
+ # Finish the UIA flow.
+ session = channel.json_body["session"]
+ channel = self.make_request(
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": LoginType.DUMMY}},
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ return channel.json_body
+
def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
|