summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/11790.feature1
-rw-r--r--docs/modules/password_auth_provider_callbacks.md62
-rw-r--r--synapse/handlers/auth.py58
-rw-r--r--synapse/module_api/__init__.py22
-rw-r--r--synapse/rest/client/register.py12
-rw-r--r--tests/handlers/test_password_providers.py79
6 files changed, 231 insertions, 3 deletions
diff --git a/changelog.d/11790.feature b/changelog.d/11790.feature
new file mode 100644
index 0000000000..4a5cc8ec37
--- /dev/null
+++ b/changelog.d/11790.feature
@@ -0,0 +1 @@
+Add a module callback to set username at registration.
diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md
index e53abf6409..ec8324d292 100644
--- a/docs/modules/password_auth_provider_callbacks.md
+++ b/docs/modules/password_auth_provider_callbacks.md
@@ -105,6 +105,68 @@ device ID), and the (now deactivated) access token.
 
 If multiple modules implement this callback, Synapse runs them all in order.
 
+### `get_username_for_registration`
+
+_First introduced in Synapse v1.52.0_
+
+```python
+async def get_username_for_registration(
+    uia_results: Dict[str, Any],
+    params: Dict[str, Any],
+) -> Optional[str]
+```
+
+Called when registering a new user. The module can return a username to set for the user
+being registered by returning it as a string, or `None` if it doesn't wish to force a
+username for this user. If a username is returned, it will be used as the local part of a
+user's full Matrix ID (e.g. it's `alice` in `@alice:example.com`).
+
+This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api)
+has been completed by the user. It is not called when registering a user via SSO. It is
+passed two dictionaries, which include the information that the user has provided during
+the registration process.
+
+The first dictionary contains the results of the [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api)
+flow followed by the user. Its keys are the identifiers of every step involved in the flow,
+associated with either a boolean value indicating whether the step was correctly completed,
+or additional information (e.g. email address, phone number...). A list of most existing
+identifiers can be found in the [Matrix specification](https://spec.matrix.org/v1.1/client-server-api/#authentication-types).
+Here's an example featuring all currently supported keys:
+
+```python
+{
+    "m.login.dummy": True,  # Dummy authentication
+    "m.login.terms": True,  # User has accepted the terms of service for the homeserver
+    "m.login.recaptcha": True,  # User has completed the recaptcha challenge
+    "m.login.email.identity": {  # User has provided and verified an email address
+        "medium": "email",
+        "address": "alice@example.com",
+        "validated_at": 1642701357084,
+    },
+    "m.login.msisdn": {  # User has provided and verified a phone number
+        "medium": "msisdn",
+        "address": "33123456789",
+        "validated_at": 1642701357084,
+    },
+    "org.matrix.msc3231.login.registration_token": "sometoken",  # User has registered through the flow described in MSC3231
+}
+```
+
+The second dictionary contains the parameters provided by the user's client in the request
+to `/_matrix/client/v3/register`. See the [Matrix specification](https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3register)
+for a complete list of these parameters.
+
+If the module cannot, or does not wish to, generate a username for this user, it must
+return `None`.
+
+If multiple modules implement this callback, they will be considered in order. If a
+callback returns `None`, Synapse falls through to the next one. The value of the first
+callback that does not return `None` will be used. If this happens, Synapse will not call
+any of the subsequent implementations of this callback. If every callback return `None`,
+the username provided by the user is used, if any (otherwise one is automatically
+generated).
+
+
 ## Example
 
 The example module below implements authentication checkers for two different login types: 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index bd1a322563..e32c93e234 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2060,6 +2060,10 @@ CHECK_AUTH_CALLBACK = Callable[
         Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
     ],
 ]
+GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
+    [JsonDict, JsonDict],
+    Awaitable[Optional[str]],
+]
 
 
 class PasswordAuthProvider:
@@ -2072,6 +2076,9 @@ class PasswordAuthProvider:
         # lists of callbacks
         self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
         self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
+        self.get_username_for_registration_callbacks: List[
+            GET_USERNAME_FOR_REGISTRATION_CALLBACK
+        ] = []
 
         # Mapping from login type to login parameters
         self._supported_login_types: Dict[str, Iterable[str]] = {}
@@ -2086,6 +2093,9 @@ class PasswordAuthProvider:
         auth_checkers: Optional[
             Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
         ] = None,
+        get_username_for_registration: Optional[
+            GET_USERNAME_FOR_REGISTRATION_CALLBACK
+        ] = None,
     ) -> None:
         # Register check_3pid_auth callback
         if check_3pid_auth is not None:
@@ -2130,6 +2140,11 @@ class PasswordAuthProvider:
                 # Add the new method to the list of auth_checker_callbacks for this login type
                 self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
 
+        if get_username_for_registration is not None:
+            self.get_username_for_registration_callbacks.append(
+                get_username_for_registration,
+            )
+
     def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
         """Get the login types supported by this password provider
 
@@ -2285,3 +2300,46 @@ class PasswordAuthProvider:
             except Exception as e:
                 logger.warning("Failed to run module API callback %s: %s", callback, e)
                 continue
+
+    async def get_username_for_registration(
+        self,
+        uia_results: JsonDict,
+        params: JsonDict,
+    ) -> Optional[str]:
+        """Defines the username to use when registering the user, using the credentials
+        and parameters provided during the UIA flow.
+
+        Stops at the first callback that returns a string.
+
+        Args:
+            uia_results: The credentials provided during the UIA flow.
+            params: The parameters provided by the registration request.
+
+        Returns:
+            The localpart to use when registering this user, or None if no module
+            returned a localpart.
+        """
+        for callback in self.get_username_for_registration_callbacks:
+            try:
+                res = await callback(uia_results, params)
+
+                if isinstance(res, str):
+                    return res
+                elif res is not None:
+                    # mypy complains that this line is unreachable because it assumes the
+                    # data returned by the module fits the expected type. We just want
+                    # to make sure this is the case.
+                    logger.warning(  # type: ignore[unreachable]
+                        "Ignoring non-string value returned by"
+                        " get_username_for_registration callback %s: %s",
+                        callback,
+                        res,
+                    )
+            except Exception as e:
+                logger.error(
+                    "Module raised an exception in get_username_for_registration: %s",
+                    e,
+                )
+                raise SynapseError(code=500, msg="Internal Server Error")
+
+        return None
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 662e60bc33..788b2e47d5 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -71,6 +71,7 @@ from synapse.handlers.account_validity import (
 from synapse.handlers.auth import (
     CHECK_3PID_AUTH_CALLBACK,
     CHECK_AUTH_CALLBACK,
+    GET_USERNAME_FOR_REGISTRATION_CALLBACK,
     ON_LOGGED_OUT_CALLBACK,
     AuthHandler,
 )
@@ -177,6 +178,7 @@ class ModuleApi:
         self._presence_stream = hs.get_event_sources().sources.presence
         self._state = hs.get_state_handler()
         self._clock: Clock = hs.get_clock()
+        self._registration_handler = hs.get_registration_handler()
         self._send_email_handler = hs.get_send_email_handler()
         self.custom_template_dir = hs.config.server.custom_template_directory
 
@@ -310,6 +312,9 @@ class ModuleApi:
         auth_checkers: Optional[
             Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
         ] = None,
+        get_username_for_registration: Optional[
+            GET_USERNAME_FOR_REGISTRATION_CALLBACK
+        ] = None,
     ) -> None:
         """Registers callbacks for password auth provider capabilities.
 
@@ -319,6 +324,7 @@ class ModuleApi:
             check_3pid_auth=check_3pid_auth,
             on_logged_out=on_logged_out,
             auth_checkers=auth_checkers,
+            get_username_for_registration=get_username_for_registration,
         )
 
     def register_background_update_controller_callbacks(
@@ -1202,6 +1208,22 @@ class ModuleApi:
         """
         return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs)
 
+    async def check_username(self, username: str) -> None:
+        """Checks if the provided username uses the grammar defined in the Matrix
+        specification, and is already being used by an existing user.
+
+        Added in Synapse v1.52.0.
+
+        Args:
+            username: The username to check. This is the local part of the user's full
+                Matrix user ID, i.e. it's "alice" if the full user ID is "@alice:foo.com".
+
+        Raises:
+            SynapseError with the errcode "M_USER_IN_USE" if the username is already in
+            use.
+        """
+        await self._registration_handler.check_username(username)
+
 
 class PublicRoomListManager:
     """Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index c59dae7c03..e3492f9f93 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -425,6 +425,7 @@ class RegisterRestServlet(RestServlet):
         self.ratelimiter = hs.get_registration_ratelimiter()
         self.password_policy_handler = hs.get_password_policy_handler()
         self.clock = hs.get_clock()
+        self.password_auth_provider = hs.get_password_auth_provider()
         self._registration_enabled = self.hs.config.registration.enable_registration
         self._refresh_tokens_enabled = (
             hs.config.registration.refreshable_access_token_lifetime is not None
@@ -638,7 +639,16 @@ class RegisterRestServlet(RestServlet):
             if not password_hash:
                 raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
 
-            desired_username = params.get("username", None)
+            desired_username = await (
+                self.password_auth_provider.get_username_for_registration(
+                    auth_result,
+                    params,
+                )
+            )
+
+            if desired_username is None:
+                desired_username = params.get("username", None)
+
             guest_access_token = params.get("guest_access_token", None)
 
             if desired_username is not None:
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 2add72b28a..94809cb8be 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -20,10 +20,11 @@ from unittest.mock import Mock
 from twisted.internet import defer
 
 import synapse
+from synapse.api.constants import LoginType
 from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login, logout
-from synapse.types import JsonDict
+from synapse.rest.client import devices, login, logout, register
+from synapse.types import JsonDict, UserID
 
 from tests import unittest
 from tests.server import FakeChannel
@@ -156,6 +157,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         login.register_servlets,
         devices.register_servlets,
         logout.register_servlets,
+        register.register_servlets,
     ]
 
     def setUp(self):
@@ -745,6 +747,79 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         on_logged_out.assert_called_once()
         self.assertTrue(self.called)
 
+    def test_username(self):
+        """Tests that the get_username_for_registration callback can define the username
+        of a user when registering.
+        """
+        self._setup_get_username_for_registration()
+
+        username = "rin"
+        channel = self.make_request(
+            "POST",
+            "/register",
+            {
+                "username": username,
+                "password": "bar",
+                "auth": {"type": LoginType.DUMMY},
+            },
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Our callback takes the username and appends "-foo" to it, check that's what we
+        # have.
+        mxid = channel.json_body["user_id"]
+        self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
+
+    def test_username_uia(self):
+        """Tests that the get_username_for_registration callback is only called at the
+        end of the UIA flow.
+        """
+        m = self._setup_get_username_for_registration()
+
+        # Initiate the UIA flow.
+        username = "rin"
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"username": username, "type": "m.login.password", "password": "bar"},
+        )
+        self.assertEqual(channel.code, 401)
+        self.assertIn("session", channel.json_body)
+
+        # Check that the callback hasn't been called yet.
+        m.assert_not_called()
+
+        # Finish the UIA flow.
+        session = channel.json_body["session"]
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"auth": {"session": session, "type": LoginType.DUMMY}},
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        mxid = channel.json_body["user_id"]
+        self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
+
+        # Check that the callback has been called.
+        m.assert_called_once()
+
+    def _setup_get_username_for_registration(self) -> Mock:
+        """Registers a get_username_for_registration callback that appends "-foo" to the
+        username the client is trying to register.
+        """
+
+        async def get_username_for_registration(uia_results, params):
+            self.assertIn(LoginType.DUMMY, uia_results)
+            username = params["username"]
+            return username + "-foo"
+
+        m = Mock(side_effect=get_username_for_registration)
+
+        password_auth_provider = self.hs.get_password_auth_provider()
+        password_auth_provider.get_username_for_registration_callbacks.append(m)
+
+        return m
+
     def _get_login_flows(self) -> JsonDict:
         channel = self.make_request("GET", "/_matrix/client/r0/login")
         self.assertEqual(channel.code, 200, channel.result)