diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index bd1a322563..e32c93e234 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2060,6 +2060,10 @@ CHECK_AUTH_CALLBACK = Callable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
+GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
+ [JsonDict, JsonDict],
+ Awaitable[Optional[str]],
+]
class PasswordAuthProvider:
@@ -2072,6 +2076,9 @@ class PasswordAuthProvider:
# lists of callbacks
self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
+ self.get_username_for_registration_callbacks: List[
+ GET_USERNAME_FOR_REGISTRATION_CALLBACK
+ ] = []
# Mapping from login type to login parameters
self._supported_login_types: Dict[str, Iterable[str]] = {}
@@ -2086,6 +2093,9 @@ class PasswordAuthProvider:
auth_checkers: Optional[
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
] = None,
+ get_username_for_registration: Optional[
+ GET_USERNAME_FOR_REGISTRATION_CALLBACK
+ ] = None,
) -> None:
# Register check_3pid_auth callback
if check_3pid_auth is not None:
@@ -2130,6 +2140,11 @@ class PasswordAuthProvider:
# Add the new method to the list of auth_checker_callbacks for this login type
self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
+ if get_username_for_registration is not None:
+ self.get_username_for_registration_callbacks.append(
+ get_username_for_registration,
+ )
+
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider
@@ -2285,3 +2300,46 @@ class PasswordAuthProvider:
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
+
+ async def get_username_for_registration(
+ self,
+ uia_results: JsonDict,
+ params: JsonDict,
+ ) -> Optional[str]:
+ """Defines the username to use when registering the user, using the credentials
+ and parameters provided during the UIA flow.
+
+ Stops at the first callback that returns a string.
+
+ Args:
+ uia_results: The credentials provided during the UIA flow.
+ params: The parameters provided by the registration request.
+
+ Returns:
+ The localpart to use when registering this user, or None if no module
+ returned a localpart.
+ """
+ for callback in self.get_username_for_registration_callbacks:
+ try:
+ res = await callback(uia_results, params)
+
+ if isinstance(res, str):
+ return res
+ elif res is not None:
+ # mypy complains that this line is unreachable because it assumes the
+ # data returned by the module fits the expected type. We just want
+ # to make sure this is the case.
+ logger.warning( # type: ignore[unreachable]
+ "Ignoring non-string value returned by"
+ " get_username_for_registration callback %s: %s",
+ callback,
+ res,
+ )
+ except Exception as e:
+ logger.error(
+ "Module raised an exception in get_username_for_registration: %s",
+ e,
+ )
+ raise SynapseError(code=500, msg="Internal Server Error")
+
+ return None
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 662e60bc33..788b2e47d5 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -71,6 +71,7 @@ from synapse.handlers.account_validity import (
from synapse.handlers.auth import (
CHECK_3PID_AUTH_CALLBACK,
CHECK_AUTH_CALLBACK,
+ GET_USERNAME_FOR_REGISTRATION_CALLBACK,
ON_LOGGED_OUT_CALLBACK,
AuthHandler,
)
@@ -177,6 +178,7 @@ class ModuleApi:
self._presence_stream = hs.get_event_sources().sources.presence
self._state = hs.get_state_handler()
self._clock: Clock = hs.get_clock()
+ self._registration_handler = hs.get_registration_handler()
self._send_email_handler = hs.get_send_email_handler()
self.custom_template_dir = hs.config.server.custom_template_directory
@@ -310,6 +312,9 @@ class ModuleApi:
auth_checkers: Optional[
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
] = None,
+ get_username_for_registration: Optional[
+ GET_USERNAME_FOR_REGISTRATION_CALLBACK
+ ] = None,
) -> None:
"""Registers callbacks for password auth provider capabilities.
@@ -319,6 +324,7 @@ class ModuleApi:
check_3pid_auth=check_3pid_auth,
on_logged_out=on_logged_out,
auth_checkers=auth_checkers,
+ get_username_for_registration=get_username_for_registration,
)
def register_background_update_controller_callbacks(
@@ -1202,6 +1208,22 @@ class ModuleApi:
"""
return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs)
+ async def check_username(self, username: str) -> None:
+ """Checks if the provided username uses the grammar defined in the Matrix
+ specification, and is already being used by an existing user.
+
+ Added in Synapse v1.52.0.
+
+ Args:
+ username: The username to check. This is the local part of the user's full
+ Matrix user ID, i.e. it's "alice" if the full user ID is "@alice:foo.com".
+
+ Raises:
+ SynapseError with the errcode "M_USER_IN_USE" if the username is already in
+ use.
+ """
+ await self._registration_handler.check_username(username)
+
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index c59dae7c03..e3492f9f93 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -425,6 +425,7 @@ class RegisterRestServlet(RestServlet):
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
+ self.password_auth_provider = hs.get_password_auth_provider()
self._registration_enabled = self.hs.config.registration.enable_registration
self._refresh_tokens_enabled = (
hs.config.registration.refreshable_access_token_lifetime is not None
@@ -638,7 +639,16 @@ class RegisterRestServlet(RestServlet):
if not password_hash:
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
- desired_username = params.get("username", None)
+ desired_username = await (
+ self.password_auth_provider.get_username_for_registration(
+ auth_result,
+ params,
+ )
+ )
+
+ if desired_username is None:
+ desired_username = params.get("username", None)
+
guest_access_token = params.get("guest_access_token", None)
if desired_username is not None:
|