summary refs log tree commit diff
path: root/tests/handlers/test_password_providers.py
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2022-02-17 17:54:16 +0100
committerGitHub <noreply@github.com>2022-02-17 16:54:16 +0000
commit707049c6ff61193ffdfba909b4f17e9158c1d3e1 (patch)
treeb0831e4f9066abf41b2b2e89c13d821d0907cd6c /tests/handlers/test_password_providers.py
parentFaster joins: parse msc3706 fields in send_join response (#12011) (diff)
downloadsynapse-707049c6ff61193ffdfba909b4f17e9158c1d3e1.tar.xz
Allow modules to set a display name on registration (#12009)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
Diffstat (limited to 'tests/handlers/test_password_providers.py')
-rw-r--r--tests/handlers/test_password_providers.py123
1 files changed, 93 insertions, 30 deletions
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 4740dd0a65..49d832de81 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -84,7 +84,7 @@ class CustomAuthProvider:
 
     def __init__(self, config, api: ModuleApi):
         api.register_password_auth_provider_callbacks(
-            auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
+            auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
         )
 
     def check_auth(self, *args):
@@ -122,7 +122,7 @@ class PasswordCustomAuthProvider:
             auth_checkers={
                 ("test.login_type", ("test_field",)): self.check_auth,
                 ("m.login.password", ("password",)): self.check_auth,
-            },
+            }
         )
         pass
 
@@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         account.register_servlets,
     ]
 
+    CALLBACK_USERNAME = "get_username_for_registration"
+    CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
+
     def setUp(self):
         # we use a global mock device, so make sure we are starting with a clean slate
         mock_password_provider.reset_mock()
@@ -754,7 +757,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         """Tests that the get_username_for_registration callback can define the username
         of a user when registering.
         """
-        self._setup_get_username_for_registration()
+        self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_USERNAME,
+        )
 
         username = "rin"
         channel = self.make_request(
@@ -777,30 +782,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         """Tests that the get_username_for_registration callback is only called at the
         end of the UIA flow.
         """
-        m = self._setup_get_username_for_registration()
-
-        # Initiate the UIA flow.
-        username = "rin"
-        channel = self.make_request(
-            "POST",
-            "register",
-            {"username": username, "type": "m.login.password", "password": "bar"},
+        m = self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_USERNAME,
         )
-        self.assertEqual(channel.code, 401)
-        self.assertIn("session", channel.json_body)
 
-        # Check that the callback hasn't been called yet.
-        m.assert_not_called()
+        username = "rin"
+        res = self._do_uia_assert_mock_not_called(username, m)
 
-        # Finish the UIA flow.
-        session = channel.json_body["session"]
-        channel = self.make_request(
-            "POST",
-            "register",
-            {"auth": {"session": session, "type": LoginType.DUMMY}},
-        )
-        self.assertEqual(channel.code, 200, channel.json_body)
-        mxid = channel.json_body["user_id"]
+        mxid = res["user_id"]
         self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
 
         # Check that the callback has been called.
@@ -817,6 +806,56 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self._test_3pid_allowed("rin", False)
         self._test_3pid_allowed("kitay", True)
 
+    def test_displayname(self):
+        """Tests that the get_displayname_for_registration callback can define the
+        display name of a user when registering.
+        """
+        self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_DISPLAYNAME,
+        )
+
+        username = "rin"
+        channel = self.make_request(
+            "POST",
+            "/register",
+            {
+                "username": username,
+                "password": "bar",
+                "auth": {"type": LoginType.DUMMY},
+            },
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Our callback takes the username and appends "-foo" to it, check that's what we
+        # have.
+        user_id = UserID.from_string(channel.json_body["user_id"])
+        display_name = self.get_success(
+            self.hs.get_profile_handler().get_displayname(user_id)
+        )
+
+        self.assertEqual(display_name, username + "-foo")
+
+    def test_displayname_uia(self):
+        """Tests that the get_displayname_for_registration callback is only called at the
+        end of the UIA flow.
+        """
+        m = self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_DISPLAYNAME,
+        )
+
+        username = "rin"
+        res = self._do_uia_assert_mock_not_called(username, m)
+
+        user_id = UserID.from_string(res["user_id"])
+        display_name = self.get_success(
+            self.hs.get_profile_handler().get_displayname(user_id)
+        )
+
+        self.assertEqual(display_name, username + "-foo")
+
+        # Check that the callback has been called.
+        m.assert_called_once()
+
     def _test_3pid_allowed(self, username: str, registration: bool):
         """Tests that the "is_3pid_allowed" module callback is called correctly, using
         either /register or /account URLs depending on the arguments.
@@ -877,23 +916,47 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
         m.assert_called_once_with("email", "bar@test.com", registration)
 
-    def _setup_get_username_for_registration(self) -> Mock:
-        """Registers a get_username_for_registration callback that appends "-foo" to the
-        username the client is trying to register.
+    def _setup_get_name_for_registration(self, callback_name: str) -> Mock:
+        """Registers either a get_username_for_registration callback or a
+        get_displayname_for_registration callback that appends "-foo" to the username the
+        client is trying to register.
         """
 
-        async def get_username_for_registration(uia_results, params):
+        async def callback(uia_results, params):
             self.assertIn(LoginType.DUMMY, uia_results)
             username = params["username"]
             return username + "-foo"
 
-        m = Mock(side_effect=get_username_for_registration)
+        m = Mock(side_effect=callback)
 
         password_auth_provider = self.hs.get_password_auth_provider()
-        password_auth_provider.get_username_for_registration_callbacks.append(m)
+        getattr(password_auth_provider, callback_name + "_callbacks").append(m)
 
         return m
 
+    def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict:
+        # Initiate the UIA flow.
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"username": username, "type": "m.login.password", "password": "bar"},
+        )
+        self.assertEqual(channel.code, 401)
+        self.assertIn("session", channel.json_body)
+
+        # Check that the callback hasn't been called yet.
+        m.assert_not_called()
+
+        # Finish the UIA flow.
+        session = channel.json_body["session"]
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"auth": {"session": session, "type": LoginType.DUMMY}},
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        return channel.json_body
+
     def _get_login_flows(self) -> JsonDict:
         channel = self.make_request("GET", "/_matrix/client/r0/login")
         self.assertEqual(channel.code, 200, channel.result)