diff options
author | Brendan Abolivier <babolivier@matrix.org> | 2022-02-17 17:54:16 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-02-17 16:54:16 +0000 |
commit | 707049c6ff61193ffdfba909b4f17e9158c1d3e1 (patch) | |
tree | b0831e4f9066abf41b2b2e89c13d821d0907cd6c /synapse | |
parent | Faster joins: parse msc3706 fields in send_join response (#12011) (diff) | |
download | synapse-707049c6ff61193ffdfba909b4f17e9158c1d3e1.tar.xz |
Allow modules to set a display name on registration (#12009)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
Diffstat (limited to '')
-rw-r--r-- | synapse/handlers/auth.py | 58 | ||||
-rw-r--r-- | synapse/module_api/__init__.py | 5 | ||||
-rw-r--r-- | synapse/rest/client/register.py | 7 |
3 files changed, 70 insertions, 0 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 6959d1aa7e..572f54b1e3 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -2064,6 +2064,10 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[ [JsonDict, JsonDict], Awaitable[Optional[str]], ] +GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[ + [JsonDict, JsonDict], + Awaitable[Optional[str]], +] IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] @@ -2080,6 +2084,9 @@ class PasswordAuthProvider: self.get_username_for_registration_callbacks: List[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = [] + self.get_displayname_for_registration_callbacks: List[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = [] self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] # Mapping from login type to login parameters @@ -2099,6 +2106,9 @@ class PasswordAuthProvider: get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, + get_displayname_for_registration: Optional[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: # Register check_3pid_auth callback if check_3pid_auth is not None: @@ -2148,6 +2158,11 @@ class PasswordAuthProvider: get_username_for_registration, ) + if get_displayname_for_registration is not None: + self.get_displayname_for_registration_callbacks.append( + get_displayname_for_registration, + ) + if is_3pid_allowed is not None: self.is_3pid_allowed_callbacks.append(is_3pid_allowed) @@ -2350,6 +2365,49 @@ class PasswordAuthProvider: return None + async def get_displayname_for_registration( + self, + uia_results: JsonDict, + params: JsonDict, + ) -> Optional[str]: + """Defines the display name to use when registering the user, using the + credentials and parameters provided during the UIA flow. + + Stops at the first callback that returns a tuple containing at least one string. + + Args: + uia_results: The credentials provided during the UIA flow. + params: The parameters provided by the registration request. + + Returns: + A tuple which first element is the display name, and the second is an MXC URL + to the user's avatar. + """ + for callback in self.get_displayname_for_registration_callbacks: + try: + res = await callback(uia_results, params) + + if isinstance(res, str): + return res + elif res is not None: + # mypy complains that this line is unreachable because it assumes the + # data returned by the module fits the expected type. We just want + # to make sure this is the case. + logger.warning( # type: ignore[unreachable] + "Ignoring non-string value returned by" + " get_displayname_for_registration callback %s: %s", + callback, + res, + ) + except Exception as e: + logger.error( + "Module raised an exception in get_displayname_for_registration: %s", + e, + ) + raise SynapseError(code=500, msg="Internal Server Error") + + return None + async def is_3pid_allowed( self, medium: str, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d4fca36923..8a17b912d3 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -70,6 +70,7 @@ from synapse.handlers.account_validity import ( from synapse.handlers.auth import ( CHECK_3PID_AUTH_CALLBACK, CHECK_AUTH_CALLBACK, + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK, GET_USERNAME_FOR_REGISTRATION_CALLBACK, IS_3PID_ALLOWED_CALLBACK, ON_LOGGED_OUT_CALLBACK, @@ -317,6 +318,9 @@ class ModuleApi: get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, + get_displayname_for_registration: Optional[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: """Registers callbacks for password auth provider capabilities. @@ -328,6 +332,7 @@ class ModuleApi: is_3pid_allowed=is_3pid_allowed, auth_checkers=auth_checkers, get_username_for_registration=get_username_for_registration, + get_displayname_for_registration=get_displayname_for_registration, ) def register_background_update_controller_callbacks( diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index c965e2bda2..b8a5135e02 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -694,11 +694,18 @@ class RegisterRestServlet(RestServlet): session_id ) + display_name = await ( + self.password_auth_provider.get_displayname_for_registration( + auth_result, params + ) + ) + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, + default_display_name=display_name, address=client_addr, user_agent_ips=entries, ) |