diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6959d1aa7e..572f54b1e3 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2064,6 +2064,10 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict],
Awaitable[Optional[str]],
]
+GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
+ [JsonDict, JsonDict],
+ Awaitable[Optional[str]],
+]
IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
@@ -2080,6 +2084,9 @@ class PasswordAuthProvider:
self.get_username_for_registration_callbacks: List[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = []
+ self.get_displayname_for_registration_callbacks: List[
+ GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+ ] = []
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
# Mapping from login type to login parameters
@@ -2099,6 +2106,9 @@ class PasswordAuthProvider:
get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = None,
+ get_displayname_for_registration: Optional[
+ GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+ ] = None,
) -> None:
# Register check_3pid_auth callback
if check_3pid_auth is not None:
@@ -2148,6 +2158,11 @@ class PasswordAuthProvider:
get_username_for_registration,
)
+ if get_displayname_for_registration is not None:
+ self.get_displayname_for_registration_callbacks.append(
+ get_displayname_for_registration,
+ )
+
if is_3pid_allowed is not None:
self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
@@ -2350,6 +2365,49 @@ class PasswordAuthProvider:
return None
+ async def get_displayname_for_registration(
+ self,
+ uia_results: JsonDict,
+ params: JsonDict,
+ ) -> Optional[str]:
+ """Defines the display name to use when registering the user, using the
+ credentials and parameters provided during the UIA flow.
+
+ Stops at the first callback that returns a tuple containing at least one string.
+
+ Args:
+ uia_results: The credentials provided during the UIA flow.
+ params: The parameters provided by the registration request.
+
+ Returns:
+ A tuple which first element is the display name, and the second is an MXC URL
+ to the user's avatar.
+ """
+ for callback in self.get_displayname_for_registration_callbacks:
+ try:
+ res = await callback(uia_results, params)
+
+ if isinstance(res, str):
+ return res
+ elif res is not None:
+ # mypy complains that this line is unreachable because it assumes the
+ # data returned by the module fits the expected type. We just want
+ # to make sure this is the case.
+ logger.warning( # type: ignore[unreachable]
+ "Ignoring non-string value returned by"
+ " get_displayname_for_registration callback %s: %s",
+ callback,
+ res,
+ )
+ except Exception as e:
+ logger.error(
+ "Module raised an exception in get_displayname_for_registration: %s",
+ e,
+ )
+ raise SynapseError(code=500, msg="Internal Server Error")
+
+ return None
+
async def is_3pid_allowed(
self,
medium: str,
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d4fca36923..8a17b912d3 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -70,6 +70,7 @@ from synapse.handlers.account_validity import (
from synapse.handlers.auth import (
CHECK_3PID_AUTH_CALLBACK,
CHECK_AUTH_CALLBACK,
+ GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK,
GET_USERNAME_FOR_REGISTRATION_CALLBACK,
IS_3PID_ALLOWED_CALLBACK,
ON_LOGGED_OUT_CALLBACK,
@@ -317,6 +318,9 @@ class ModuleApi:
get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = None,
+ get_displayname_for_registration: Optional[
+ GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+ ] = None,
) -> None:
"""Registers callbacks for password auth provider capabilities.
@@ -328,6 +332,7 @@ class ModuleApi:
is_3pid_allowed=is_3pid_allowed,
auth_checkers=auth_checkers,
get_username_for_registration=get_username_for_registration,
+ get_displayname_for_registration=get_displayname_for_registration,
)
def register_background_update_controller_callbacks(
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index c965e2bda2..b8a5135e02 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -694,11 +694,18 @@ class RegisterRestServlet(RestServlet):
session_id
)
+ display_name = await (
+ self.password_auth_provider.get_displayname_for_registration(
+ auth_result, params
+ )
+ )
+
registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
password_hash=password_hash,
guest_access_token=guest_access_token,
threepid=threepid,
+ default_display_name=display_name,
address=client_addr,
user_agent_ips=entries,
)
|