summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/client/v1/login.py22
1 files changed, 15 insertions, 7 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 250b03a025..b9347b87c7 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -284,9 +284,7 @@ class LoginRestServlet(RestServlet):
         self,
         user_id: str,
         login_submission: JsonDict,
-        callback: Optional[
-            Callable[[Dict[str, str]], Awaitable[Dict[str, str]]]
-        ] = None,
+        callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
         create_non_existent_users: bool = False,
     ) -> Dict[str, str]:
         """Called when we've successfully authed the user and now need to
@@ -299,12 +297,12 @@ class LoginRestServlet(RestServlet):
         Args:
             user_id: ID of the user to register.
             login_submission: Dictionary of login information.
-            callback: Callback function to run after registration.
+            callback: Callback function to run after login.
             create_non_existent_users: Whether to create the user if they don't
                 exist. Defaults to False.
 
         Returns:
-            result: Dictionary of account information after successful registration.
+            result: Dictionary of account information after successful login.
         """
 
         # Before we actually log them in we check if they've already logged in
@@ -339,14 +337,24 @@ class LoginRestServlet(RestServlet):
         return result
 
     async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
+        """
+        Handle the final stage of SSO login.
+
+        Args:
+             login_submission: The JSON request body.
+
+        Returns:
+            The body of the JSON response.
+        """
         token = login_submission["token"]
         auth_handler = self.auth_handler
         user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
             token
         )
 
-        result = await self._complete_login(user_id, login_submission)
-        return result
+        return await self._complete_login(
+            user_id, login_submission, self.auth_handler._sso_login_callback
+        )
 
     async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
         token = login_submission.get("token", None)