summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py58
1 files changed, 58 insertions, 0 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index bd1a322563..e32c93e234 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2060,6 +2060,10 @@ CHECK_AUTH_CALLBACK = Callable[
         Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
     ],
 ]
+GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
+    [JsonDict, JsonDict],
+    Awaitable[Optional[str]],
+]
 
 
 class PasswordAuthProvider:
@@ -2072,6 +2076,9 @@ class PasswordAuthProvider:
         # lists of callbacks
         self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
         self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
+        self.get_username_for_registration_callbacks: List[
+            GET_USERNAME_FOR_REGISTRATION_CALLBACK
+        ] = []
 
         # Mapping from login type to login parameters
         self._supported_login_types: Dict[str, Iterable[str]] = {}
@@ -2086,6 +2093,9 @@ class PasswordAuthProvider:
         auth_checkers: Optional[
             Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
         ] = None,
+        get_username_for_registration: Optional[
+            GET_USERNAME_FOR_REGISTRATION_CALLBACK
+        ] = None,
     ) -> None:
         # Register check_3pid_auth callback
         if check_3pid_auth is not None:
@@ -2130,6 +2140,11 @@ class PasswordAuthProvider:
                 # Add the new method to the list of auth_checker_callbacks for this login type
                 self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
 
+        if get_username_for_registration is not None:
+            self.get_username_for_registration_callbacks.append(
+                get_username_for_registration,
+            )
+
     def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
         """Get the login types supported by this password provider
 
@@ -2285,3 +2300,46 @@ class PasswordAuthProvider:
             except Exception as e:
                 logger.warning("Failed to run module API callback %s: %s", callback, e)
                 continue
+
+    async def get_username_for_registration(
+        self,
+        uia_results: JsonDict,
+        params: JsonDict,
+    ) -> Optional[str]:
+        """Defines the username to use when registering the user, using the credentials
+        and parameters provided during the UIA flow.
+
+        Stops at the first callback that returns a string.
+
+        Args:
+            uia_results: The credentials provided during the UIA flow.
+            params: The parameters provided by the registration request.
+
+        Returns:
+            The localpart to use when registering this user, or None if no module
+            returned a localpart.
+        """
+        for callback in self.get_username_for_registration_callbacks:
+            try:
+                res = await callback(uia_results, params)
+
+                if isinstance(res, str):
+                    return res
+                elif res is not None:
+                    # mypy complains that this line is unreachable because it assumes the
+                    # data returned by the module fits the expected type. We just want
+                    # to make sure this is the case.
+                    logger.warning(  # type: ignore[unreachable]
+                        "Ignoring non-string value returned by"
+                        " get_username_for_registration callback %s: %s",
+                        callback,
+                        res,
+                    )
+            except Exception as e:
+                logger.error(
+                    "Module raised an exception in get_username_for_registration: %s",
+                    e,
+                )
+                raise SynapseError(code=500, msg="Internal Server Error")
+
+        return None