summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2022-01-26 14:21:13 +0000
committerGitHub <noreply@github.com>2022-01-26 14:21:13 +0000
commit2d3bd9aa670eedd299cc03093459929adec41918 (patch)
treeb7baca8830fc7b3fde9c596405097dd6c6295cfc /synapse/handlers/auth.py
parentImprovements to bundling aggregations. (#11815) (diff)
downloadsynapse-2d3bd9aa670eedd299cc03093459929adec41918.tar.xz
Add a module callback to set username at registration (#11790)
This is in the context of mainlining the Tchap fork of Synapse. Currently in Tchap usernames are derived from the user's email address (extracted from the UIA results, more specifically the m.login.email.identity step).
This change also exports the check_username method from the registration handler as part of the module API, so that a module can check if the username it's trying to generate is correct and doesn't conflict with an existing one, and fallback gracefully if not.

Co-authored-by: David Robertson <davidr@element.io>
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py58
1 files changed, 58 insertions, 0 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index bd1a322563..e32c93e234 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2060,6 +2060,10 @@ CHECK_AUTH_CALLBACK = Callable[
         Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
     ],
 ]
+GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
+    [JsonDict, JsonDict],
+    Awaitable[Optional[str]],
+]
 
 
 class PasswordAuthProvider:
@@ -2072,6 +2076,9 @@ class PasswordAuthProvider:
         # lists of callbacks
         self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
         self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
+        self.get_username_for_registration_callbacks: List[
+            GET_USERNAME_FOR_REGISTRATION_CALLBACK
+        ] = []
 
         # Mapping from login type to login parameters
         self._supported_login_types: Dict[str, Iterable[str]] = {}
@@ -2086,6 +2093,9 @@ class PasswordAuthProvider:
         auth_checkers: Optional[
             Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
         ] = None,
+        get_username_for_registration: Optional[
+            GET_USERNAME_FOR_REGISTRATION_CALLBACK
+        ] = None,
     ) -> None:
         # Register check_3pid_auth callback
         if check_3pid_auth is not None:
@@ -2130,6 +2140,11 @@ class PasswordAuthProvider:
                 # Add the new method to the list of auth_checker_callbacks for this login type
                 self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
 
+        if get_username_for_registration is not None:
+            self.get_username_for_registration_callbacks.append(
+                get_username_for_registration,
+            )
+
     def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
         """Get the login types supported by this password provider
 
@@ -2285,3 +2300,46 @@ class PasswordAuthProvider:
             except Exception as e:
                 logger.warning("Failed to run module API callback %s: %s", callback, e)
                 continue
+
+    async def get_username_for_registration(
+        self,
+        uia_results: JsonDict,
+        params: JsonDict,
+    ) -> Optional[str]:
+        """Defines the username to use when registering the user, using the credentials
+        and parameters provided during the UIA flow.
+
+        Stops at the first callback that returns a string.
+
+        Args:
+            uia_results: The credentials provided during the UIA flow.
+            params: The parameters provided by the registration request.
+
+        Returns:
+            The localpart to use when registering this user, or None if no module
+            returned a localpart.
+        """
+        for callback in self.get_username_for_registration_callbacks:
+            try:
+                res = await callback(uia_results, params)
+
+                if isinstance(res, str):
+                    return res
+                elif res is not None:
+                    # mypy complains that this line is unreachable because it assumes the
+                    # data returned by the module fits the expected type. We just want
+                    # to make sure this is the case.
+                    logger.warning(  # type: ignore[unreachable]
+                        "Ignoring non-string value returned by"
+                        " get_username_for_registration callback %s: %s",
+                        callback,
+                        res,
+                    )
+            except Exception as e:
+                logger.error(
+                    "Module raised an exception in get_username_for_registration: %s",
+                    e,
+                )
+                raise SynapseError(code=500, msg="Internal Server Error")
+
+        return None