summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12009.feature1
-rw-r--r--docs/modules/password_auth_provider_callbacks.md35
-rw-r--r--synapse/handlers/auth.py58
-rw-r--r--synapse/module_api/__init__.py5
-rw-r--r--synapse/rest/client/register.py7
-rw-r--r--tests/handlers/test_password_providers.py123
6 files changed, 195 insertions, 34 deletions
diff --git a/changelog.d/12009.feature b/changelog.d/12009.feature
new file mode 100644
index 0000000000..c8a531481e
--- /dev/null
+++ b/changelog.d/12009.feature
@@ -0,0 +1 @@
+Enable modules to set a custom display name when registering a user.
diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md
index 88b59bb09e..ec810fd292 100644
--- a/docs/modules/password_auth_provider_callbacks.md
+++ b/docs/modules/password_auth_provider_callbacks.md
@@ -85,7 +85,7 @@ If the authentication is unsuccessful, the module must return `None`.
 If multiple modules implement this callback, they will be considered in order. If a
 callback returns `None`, Synapse falls through to the next one. The value of the first
 callback that does not return `None` will be used. If this happens, Synapse will not call
-any of the subsequent implementations of this callback. If every callback return `None`,
+any of the subsequent implementations of this callback. If every callback returns `None`,
 the authentication is denied.
 
 ### `on_logged_out`
@@ -162,10 +162,38 @@ return `None`.
 If multiple modules implement this callback, they will be considered in order. If a
 callback returns `None`, Synapse falls through to the next one. The value of the first
 callback that does not return `None` will be used. If this happens, Synapse will not call
-any of the subsequent implementations of this callback. If every callback return `None`,
+any of the subsequent implementations of this callback. If every callback returns `None`,
 the username provided by the user is used, if any (otherwise one is automatically
 generated).
 
+### `get_displayname_for_registration`
+
+_First introduced in Synapse v1.54.0_
+
+```python
+async def get_displayname_for_registration(
+    uia_results: Dict[str, Any],
+    params: Dict[str, Any],
+) -> Optional[str]
+```
+
+Called when registering a new user. The module can return a display name to set for the
+user being registered by returning it as a string, or `None` if it doesn't wish to force a
+display name for this user.
+
+This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api)
+has been completed by the user. It is not called when registering a user via SSO. It is
+passed two dictionaries, which include the information that the user has provided during
+the registration process. These dictionaries are identical to the ones passed to
+[`get_username_for_registration`](#get_username_for_registration), so refer to the
+documentation of this callback for more information about them.
+
+If multiple modules implement this callback, they will be considered in order. If a
+callback returns `None`, Synapse falls through to the next one. The value of the first
+callback that does not return `None` will be used. If this happens, Synapse will not call
+any of the subsequent implementations of this callback. If every callback returns `None`,
+the username will be used (e.g. `alice` if the user being registered is `@alice:example.com`).
+
 ## `is_3pid_allowed`
 
 _First introduced in Synapse v1.53.0_
@@ -194,8 +222,7 @@ The example module below implements authentication checkers for two different lo
     - Is checked by the method: `self.check_my_login`
 - `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based))
     - Expects a `password` field to be sent to `/login`
-    - Is checked by the method: `self.check_pass` 
-
+    - Is checked by the method: `self.check_pass`
 
 ```python
 from typing import Awaitable, Callable, Optional, Tuple
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6959d1aa7e..572f54b1e3 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2064,6 +2064,10 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
     [JsonDict, JsonDict],
     Awaitable[Optional[str]],
 ]
+GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
+    [JsonDict, JsonDict],
+    Awaitable[Optional[str]],
+]
 IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
 
 
@@ -2080,6 +2084,9 @@ class PasswordAuthProvider:
         self.get_username_for_registration_callbacks: List[
             GET_USERNAME_FOR_REGISTRATION_CALLBACK
         ] = []
+        self.get_displayname_for_registration_callbacks: List[
+            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+        ] = []
         self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
 
         # Mapping from login type to login parameters
@@ -2099,6 +2106,9 @@ class PasswordAuthProvider:
         get_username_for_registration: Optional[
             GET_USERNAME_FOR_REGISTRATION_CALLBACK
         ] = None,
+        get_displayname_for_registration: Optional[
+            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+        ] = None,
     ) -> None:
         # Register check_3pid_auth callback
         if check_3pid_auth is not None:
@@ -2148,6 +2158,11 @@ class PasswordAuthProvider:
                 get_username_for_registration,
             )
 
+        if get_displayname_for_registration is not None:
+            self.get_displayname_for_registration_callbacks.append(
+                get_displayname_for_registration,
+            )
+
         if is_3pid_allowed is not None:
             self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
 
@@ -2350,6 +2365,49 @@ class PasswordAuthProvider:
 
         return None
 
+    async def get_displayname_for_registration(
+        self,
+        uia_results: JsonDict,
+        params: JsonDict,
+    ) -> Optional[str]:
+        """Defines the display name to use when registering the user, using the
+        credentials and parameters provided during the UIA flow.
+
+        Stops at the first callback that returns a tuple containing at least one string.
+
+        Args:
+            uia_results: The credentials provided during the UIA flow.
+            params: The parameters provided by the registration request.
+
+        Returns:
+            A tuple which first element is the display name, and the second is an MXC URL
+            to the user's avatar.
+        """
+        for callback in self.get_displayname_for_registration_callbacks:
+            try:
+                res = await callback(uia_results, params)
+
+                if isinstance(res, str):
+                    return res
+                elif res is not None:
+                    # mypy complains that this line is unreachable because it assumes the
+                    # data returned by the module fits the expected type. We just want
+                    # to make sure this is the case.
+                    logger.warning(  # type: ignore[unreachable]
+                        "Ignoring non-string value returned by"
+                        " get_displayname_for_registration callback %s: %s",
+                        callback,
+                        res,
+                    )
+            except Exception as e:
+                logger.error(
+                    "Module raised an exception in get_displayname_for_registration: %s",
+                    e,
+                )
+                raise SynapseError(code=500, msg="Internal Server Error")
+
+        return None
+
     async def is_3pid_allowed(
         self,
         medium: str,
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d4fca36923..8a17b912d3 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -70,6 +70,7 @@ from synapse.handlers.account_validity import (
 from synapse.handlers.auth import (
     CHECK_3PID_AUTH_CALLBACK,
     CHECK_AUTH_CALLBACK,
+    GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK,
     GET_USERNAME_FOR_REGISTRATION_CALLBACK,
     IS_3PID_ALLOWED_CALLBACK,
     ON_LOGGED_OUT_CALLBACK,
@@ -317,6 +318,9 @@ class ModuleApi:
         get_username_for_registration: Optional[
             GET_USERNAME_FOR_REGISTRATION_CALLBACK
         ] = None,
+        get_displayname_for_registration: Optional[
+            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+        ] = None,
     ) -> None:
         """Registers callbacks for password auth provider capabilities.
 
@@ -328,6 +332,7 @@ class ModuleApi:
             is_3pid_allowed=is_3pid_allowed,
             auth_checkers=auth_checkers,
             get_username_for_registration=get_username_for_registration,
+            get_displayname_for_registration=get_displayname_for_registration,
         )
 
     def register_background_update_controller_callbacks(
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index c965e2bda2..b8a5135e02 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -694,11 +694,18 @@ class RegisterRestServlet(RestServlet):
                 session_id
             )
 
+            display_name = await (
+                self.password_auth_provider.get_displayname_for_registration(
+                    auth_result, params
+                )
+            )
+
             registered_user_id = await self.registration_handler.register_user(
                 localpart=desired_username,
                 password_hash=password_hash,
                 guest_access_token=guest_access_token,
                 threepid=threepid,
+                default_display_name=display_name,
                 address=client_addr,
                 user_agent_ips=entries,
             )
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 4740dd0a65..49d832de81 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -84,7 +84,7 @@ class CustomAuthProvider:
 
     def __init__(self, config, api: ModuleApi):
         api.register_password_auth_provider_callbacks(
-            auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
+            auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
         )
 
     def check_auth(self, *args):
@@ -122,7 +122,7 @@ class PasswordCustomAuthProvider:
             auth_checkers={
                 ("test.login_type", ("test_field",)): self.check_auth,
                 ("m.login.password", ("password",)): self.check_auth,
-            },
+            }
         )
         pass
 
@@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         account.register_servlets,
     ]
 
+    CALLBACK_USERNAME = "get_username_for_registration"
+    CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
+
     def setUp(self):
         # we use a global mock device, so make sure we are starting with a clean slate
         mock_password_provider.reset_mock()
@@ -754,7 +757,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         """Tests that the get_username_for_registration callback can define the username
         of a user when registering.
         """
-        self._setup_get_username_for_registration()
+        self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_USERNAME,
+        )
 
         username = "rin"
         channel = self.make_request(
@@ -777,30 +782,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         """Tests that the get_username_for_registration callback is only called at the
         end of the UIA flow.
         """
-        m = self._setup_get_username_for_registration()
-
-        # Initiate the UIA flow.
-        username = "rin"
-        channel = self.make_request(
-            "POST",
-            "register",
-            {"username": username, "type": "m.login.password", "password": "bar"},
+        m = self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_USERNAME,
         )
-        self.assertEqual(channel.code, 401)
-        self.assertIn("session", channel.json_body)
 
-        # Check that the callback hasn't been called yet.
-        m.assert_not_called()
+        username = "rin"
+        res = self._do_uia_assert_mock_not_called(username, m)
 
-        # Finish the UIA flow.
-        session = channel.json_body["session"]
-        channel = self.make_request(
-            "POST",
-            "register",
-            {"auth": {"session": session, "type": LoginType.DUMMY}},
-        )
-        self.assertEqual(channel.code, 200, channel.json_body)
-        mxid = channel.json_body["user_id"]
+        mxid = res["user_id"]
         self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
 
         # Check that the callback has been called.
@@ -817,6 +806,56 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self._test_3pid_allowed("rin", False)
         self._test_3pid_allowed("kitay", True)
 
+    def test_displayname(self):
+        """Tests that the get_displayname_for_registration callback can define the
+        display name of a user when registering.
+        """
+        self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_DISPLAYNAME,
+        )
+
+        username = "rin"
+        channel = self.make_request(
+            "POST",
+            "/register",
+            {
+                "username": username,
+                "password": "bar",
+                "auth": {"type": LoginType.DUMMY},
+            },
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Our callback takes the username and appends "-foo" to it, check that's what we
+        # have.
+        user_id = UserID.from_string(channel.json_body["user_id"])
+        display_name = self.get_success(
+            self.hs.get_profile_handler().get_displayname(user_id)
+        )
+
+        self.assertEqual(display_name, username + "-foo")
+
+    def test_displayname_uia(self):
+        """Tests that the get_displayname_for_registration callback is only called at the
+        end of the UIA flow.
+        """
+        m = self._setup_get_name_for_registration(
+            callback_name=self.CALLBACK_DISPLAYNAME,
+        )
+
+        username = "rin"
+        res = self._do_uia_assert_mock_not_called(username, m)
+
+        user_id = UserID.from_string(res["user_id"])
+        display_name = self.get_success(
+            self.hs.get_profile_handler().get_displayname(user_id)
+        )
+
+        self.assertEqual(display_name, username + "-foo")
+
+        # Check that the callback has been called.
+        m.assert_called_once()
+
     def _test_3pid_allowed(self, username: str, registration: bool):
         """Tests that the "is_3pid_allowed" module callback is called correctly, using
         either /register or /account URLs depending on the arguments.
@@ -877,23 +916,47 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
         m.assert_called_once_with("email", "bar@test.com", registration)
 
-    def _setup_get_username_for_registration(self) -> Mock:
-        """Registers a get_username_for_registration callback that appends "-foo" to the
-        username the client is trying to register.
+    def _setup_get_name_for_registration(self, callback_name: str) -> Mock:
+        """Registers either a get_username_for_registration callback or a
+        get_displayname_for_registration callback that appends "-foo" to the username the
+        client is trying to register.
         """
 
-        async def get_username_for_registration(uia_results, params):
+        async def callback(uia_results, params):
             self.assertIn(LoginType.DUMMY, uia_results)
             username = params["username"]
             return username + "-foo"
 
-        m = Mock(side_effect=get_username_for_registration)
+        m = Mock(side_effect=callback)
 
         password_auth_provider = self.hs.get_password_auth_provider()
-        password_auth_provider.get_username_for_registration_callbacks.append(m)
+        getattr(password_auth_provider, callback_name + "_callbacks").append(m)
 
         return m
 
+    def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict:
+        # Initiate the UIA flow.
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"username": username, "type": "m.login.password", "password": "bar"},
+        )
+        self.assertEqual(channel.code, 401)
+        self.assertIn("session", channel.json_body)
+
+        # Check that the callback hasn't been called yet.
+        m.assert_not_called()
+
+        # Finish the UIA flow.
+        session = channel.json_body["session"]
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"auth": {"session": session, "type": LoginType.DUMMY}},
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        return channel.json_body
+
     def _get_login_flows(self) -> JsonDict:
         channel = self.make_request("GET", "/_matrix/client/r0/login")
         self.assertEqual(channel.code, 200, channel.result)