summary refs log tree commit diff
path: root/tests/handlers/test_password_providers.py
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2022-01-26 14:21:13 +0000
committerGitHub <noreply@github.com>2022-01-26 14:21:13 +0000
commit2d3bd9aa670eedd299cc03093459929adec41918 (patch)
treeb7baca8830fc7b3fde9c596405097dd6c6295cfc /tests/handlers/test_password_providers.py
parentImprovements to bundling aggregations. (#11815) (diff)
downloadsynapse-2d3bd9aa670eedd299cc03093459929adec41918.tar.xz
Add a module callback to set username at registration (#11790)
This is in the context of mainlining the Tchap fork of Synapse. Currently in Tchap usernames are derived from the user's email address (extracted from the UIA results, more specifically the m.login.email.identity step).
This change also exports the check_username method from the registration handler as part of the module API, so that a module can check if the username it's trying to generate is correct and doesn't conflict with an existing one, and fallback gracefully if not.

Co-authored-by: David Robertson <davidr@element.io>
Diffstat (limited to 'tests/handlers/test_password_providers.py')
-rw-r--r--tests/handlers/test_password_providers.py79
1 files changed, 77 insertions, 2 deletions
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 2add72b28a..94809cb8be 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -20,10 +20,11 @@ from unittest.mock import Mock
 from twisted.internet import defer
 
 import synapse
+from synapse.api.constants import LoginType
 from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login, logout
-from synapse.types import JsonDict
+from synapse.rest.client import devices, login, logout, register
+from synapse.types import JsonDict, UserID
 
 from tests import unittest
 from tests.server import FakeChannel
@@ -156,6 +157,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         login.register_servlets,
         devices.register_servlets,
         logout.register_servlets,
+        register.register_servlets,
     ]
 
     def setUp(self):
@@ -745,6 +747,79 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         on_logged_out.assert_called_once()
         self.assertTrue(self.called)
 
+    def test_username(self):
+        """Tests that the get_username_for_registration callback can define the username
+        of a user when registering.
+        """
+        self._setup_get_username_for_registration()
+
+        username = "rin"
+        channel = self.make_request(
+            "POST",
+            "/register",
+            {
+                "username": username,
+                "password": "bar",
+                "auth": {"type": LoginType.DUMMY},
+            },
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Our callback takes the username and appends "-foo" to it, check that's what we
+        # have.
+        mxid = channel.json_body["user_id"]
+        self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
+
+    def test_username_uia(self):
+        """Tests that the get_username_for_registration callback is only called at the
+        end of the UIA flow.
+        """
+        m = self._setup_get_username_for_registration()
+
+        # Initiate the UIA flow.
+        username = "rin"
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"username": username, "type": "m.login.password", "password": "bar"},
+        )
+        self.assertEqual(channel.code, 401)
+        self.assertIn("session", channel.json_body)
+
+        # Check that the callback hasn't been called yet.
+        m.assert_not_called()
+
+        # Finish the UIA flow.
+        session = channel.json_body["session"]
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"auth": {"session": session, "type": LoginType.DUMMY}},
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        mxid = channel.json_body["user_id"]
+        self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
+
+        # Check that the callback has been called.
+        m.assert_called_once()
+
+    def _setup_get_username_for_registration(self) -> Mock:
+        """Registers a get_username_for_registration callback that appends "-foo" to the
+        username the client is trying to register.
+        """
+
+        async def get_username_for_registration(uia_results, params):
+            self.assertIn(LoginType.DUMMY, uia_results)
+            username = params["username"]
+            return username + "-foo"
+
+        m = Mock(side_effect=get_username_for_registration)
+
+        password_auth_provider = self.hs.get_password_auth_provider()
+        password_auth_provider.get_username_for_registration_callbacks.append(m)
+
+        return m
+
     def _get_login_flows(self) -> JsonDict:
         channel = self.make_request("GET", "/_matrix/client/r0/login")
         self.assertEqual(channel.code, 200, channel.result)