summary refs log tree commit diff
path: root/tests/handlers/test_password_providers.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_password_providers.py')
-rw-r--r--tests/handlers/test_password_providers.py79
1 files changed, 77 insertions, 2 deletions
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 2add72b28a..94809cb8be 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -20,10 +20,11 @@ from unittest.mock import Mock
 from twisted.internet import defer
 
 import synapse
+from synapse.api.constants import LoginType
 from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login, logout
-from synapse.types import JsonDict
+from synapse.rest.client import devices, login, logout, register
+from synapse.types import JsonDict, UserID
 
 from tests import unittest
 from tests.server import FakeChannel
@@ -156,6 +157,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         login.register_servlets,
         devices.register_servlets,
         logout.register_servlets,
+        register.register_servlets,
     ]
 
     def setUp(self):
@@ -745,6 +747,79 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         on_logged_out.assert_called_once()
         self.assertTrue(self.called)
 
+    def test_username(self):
+        """Tests that the get_username_for_registration callback can define the username
+        of a user when registering.
+        """
+        self._setup_get_username_for_registration()
+
+        username = "rin"
+        channel = self.make_request(
+            "POST",
+            "/register",
+            {
+                "username": username,
+                "password": "bar",
+                "auth": {"type": LoginType.DUMMY},
+            },
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Our callback takes the username and appends "-foo" to it, check that's what we
+        # have.
+        mxid = channel.json_body["user_id"]
+        self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
+
+    def test_username_uia(self):
+        """Tests that the get_username_for_registration callback is only called at the
+        end of the UIA flow.
+        """
+        m = self._setup_get_username_for_registration()
+
+        # Initiate the UIA flow.
+        username = "rin"
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"username": username, "type": "m.login.password", "password": "bar"},
+        )
+        self.assertEqual(channel.code, 401)
+        self.assertIn("session", channel.json_body)
+
+        # Check that the callback hasn't been called yet.
+        m.assert_not_called()
+
+        # Finish the UIA flow.
+        session = channel.json_body["session"]
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"auth": {"session": session, "type": LoginType.DUMMY}},
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        mxid = channel.json_body["user_id"]
+        self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
+
+        # Check that the callback has been called.
+        m.assert_called_once()
+
+    def _setup_get_username_for_registration(self) -> Mock:
+        """Registers a get_username_for_registration callback that appends "-foo" to the
+        username the client is trying to register.
+        """
+
+        async def get_username_for_registration(uia_results, params):
+            self.assertIn(LoginType.DUMMY, uia_results)
+            username = params["username"]
+            return username + "-foo"
+
+        m = Mock(side_effect=get_username_for_registration)
+
+        password_auth_provider = self.hs.get_password_auth_provider()
+        password_auth_provider.get_username_for_registration_callbacks.append(m)
+
+        return m
+
     def _get_login_flows(self) -> JsonDict:
         channel = self.make_request("GET", "/_matrix/client/r0/login")
         self.assertEqual(channel.code, 200, channel.result)