summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
authorCallum Brown <callum@calcuode.com>2021-08-18 13:13:35 +0100
committerGitHub <noreply@github.com>2021-08-18 08:13:35 -0400
commit6e613a10d072c32e72d6b97b2d178bb840769f3e (patch)
tree0dec01aa171113e8fc9d5ca5cf7a1069edc50deb /synapse/handlers/auth.py
parentRefactor `on_receive_pdu` code (#10615) (diff)
downloadsynapse-6e613a10d072c32e72d6b97b2d178bb840769f3e.tar.xz
Display an error page during failure of fallback UIA. (#10561)
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py23
1 files changed, 14 insertions, 9 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 161b3c933c..98d3d2d97f 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -627,23 +627,28 @@ class AuthHandler(BaseHandler):
 
     async def add_oob_auth(
         self, stagetype: str, authdict: Dict[str, Any], clientip: str
-    ) -> bool:
+    ) -> None:
         """
         Adds the result of out-of-band authentication into an existing auth
         session. Currently used for adding the result of fallback auth.
+
+        Raises:
+            LoginError if the stagetype is unknown or the session is missing.
+            LoginError is raised by check_auth if authentication fails.
         """
         if stagetype not in self.checkers:
-            raise LoginError(400, "", Codes.MISSING_PARAM)
+            raise LoginError(
+                400, f"Unknown UIA stage type: {stagetype}", Codes.INVALID_PARAM
+            )
         if "session" not in authdict:
-            raise LoginError(400, "", Codes.MISSING_PARAM)
+            raise LoginError(400, "Missing session ID", Codes.MISSING_PARAM)
 
+        # If authentication fails a LoginError is raised. Otherwise, store
+        # the successful result.
         result = await self.checkers[stagetype].check_auth(authdict, clientip)
-        if result:
-            await self.store.mark_ui_auth_stage_complete(
-                authdict["session"], stagetype, result
-            )
-            return True
-        return False
+        await self.store.mark_ui_auth_stage_complete(
+            authdict["session"], stagetype, result
+        )
 
     def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
         """