summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2022-02-17 17:54:16 +0100
committerGitHub <noreply@github.com>2022-02-17 16:54:16 +0000
commit707049c6ff61193ffdfba909b4f17e9158c1d3e1 (patch)
treeb0831e4f9066abf41b2b2e89c13d821d0907cd6c /synapse
parentFaster joins: parse msc3706 fields in send_join response (#12011) (diff)
downloadsynapse-707049c6ff61193ffdfba909b4f17e9158c1d3e1.tar.xz
Allow modules to set a display name on registration (#12009)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
Diffstat (limited to '')
-rw-r--r--synapse/handlers/auth.py58
-rw-r--r--synapse/module_api/__init__.py5
-rw-r--r--synapse/rest/client/register.py7
3 files changed, 70 insertions, 0 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6959d1aa7e..572f54b1e3 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2064,6 +2064,10 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
     [JsonDict, JsonDict],
     Awaitable[Optional[str]],
 ]
+GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
+    [JsonDict, JsonDict],
+    Awaitable[Optional[str]],
+]
 IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
 
 
@@ -2080,6 +2084,9 @@ class PasswordAuthProvider:
         self.get_username_for_registration_callbacks: List[
             GET_USERNAME_FOR_REGISTRATION_CALLBACK
         ] = []
+        self.get_displayname_for_registration_callbacks: List[
+            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+        ] = []
         self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
 
         # Mapping from login type to login parameters
@@ -2099,6 +2106,9 @@ class PasswordAuthProvider:
         get_username_for_registration: Optional[
             GET_USERNAME_FOR_REGISTRATION_CALLBACK
         ] = None,
+        get_displayname_for_registration: Optional[
+            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+        ] = None,
     ) -> None:
         # Register check_3pid_auth callback
         if check_3pid_auth is not None:
@@ -2148,6 +2158,11 @@ class PasswordAuthProvider:
                 get_username_for_registration,
             )
 
+        if get_displayname_for_registration is not None:
+            self.get_displayname_for_registration_callbacks.append(
+                get_displayname_for_registration,
+            )
+
         if is_3pid_allowed is not None:
             self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
 
@@ -2350,6 +2365,49 @@ class PasswordAuthProvider:
 
         return None
 
+    async def get_displayname_for_registration(
+        self,
+        uia_results: JsonDict,
+        params: JsonDict,
+    ) -> Optional[str]:
+        """Defines the display name to use when registering the user, using the
+        credentials and parameters provided during the UIA flow.
+
+        Stops at the first callback that returns a tuple containing at least one string.
+
+        Args:
+            uia_results: The credentials provided during the UIA flow.
+            params: The parameters provided by the registration request.
+
+        Returns:
+            A tuple which first element is the display name, and the second is an MXC URL
+            to the user's avatar.
+        """
+        for callback in self.get_displayname_for_registration_callbacks:
+            try:
+                res = await callback(uia_results, params)
+
+                if isinstance(res, str):
+                    return res
+                elif res is not None:
+                    # mypy complains that this line is unreachable because it assumes the
+                    # data returned by the module fits the expected type. We just want
+                    # to make sure this is the case.
+                    logger.warning(  # type: ignore[unreachable]
+                        "Ignoring non-string value returned by"
+                        " get_displayname_for_registration callback %s: %s",
+                        callback,
+                        res,
+                    )
+            except Exception as e:
+                logger.error(
+                    "Module raised an exception in get_displayname_for_registration: %s",
+                    e,
+                )
+                raise SynapseError(code=500, msg="Internal Server Error")
+
+        return None
+
     async def is_3pid_allowed(
         self,
         medium: str,
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d4fca36923..8a17b912d3 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -70,6 +70,7 @@ from synapse.handlers.account_validity import (
 from synapse.handlers.auth import (
     CHECK_3PID_AUTH_CALLBACK,
     CHECK_AUTH_CALLBACK,
+    GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK,
     GET_USERNAME_FOR_REGISTRATION_CALLBACK,
     IS_3PID_ALLOWED_CALLBACK,
     ON_LOGGED_OUT_CALLBACK,
@@ -317,6 +318,9 @@ class ModuleApi:
         get_username_for_registration: Optional[
             GET_USERNAME_FOR_REGISTRATION_CALLBACK
         ] = None,
+        get_displayname_for_registration: Optional[
+            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+        ] = None,
     ) -> None:
         """Registers callbacks for password auth provider capabilities.
 
@@ -328,6 +332,7 @@ class ModuleApi:
             is_3pid_allowed=is_3pid_allowed,
             auth_checkers=auth_checkers,
             get_username_for_registration=get_username_for_registration,
+            get_displayname_for_registration=get_displayname_for_registration,
         )
 
     def register_background_update_controller_callbacks(
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index c965e2bda2..b8a5135e02 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -694,11 +694,18 @@ class RegisterRestServlet(RestServlet):
                 session_id
             )
 
+            display_name = await (
+                self.password_auth_provider.get_displayname_for_registration(
+                    auth_result, params
+                )
+            )
+
             registered_user_id = await self.registration_handler.register_user(
                 localpart=desired_username,
                 password_hash=password_hash,
                 guest_access_token=guest_access_token,
                 threepid=threepid,
+                default_display_name=display_name,
                 address=client_addr,
                 user_agent_ips=entries,
             )