summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/auth.py84
1 files changed, 52 insertions, 32 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3933ce171a..51888d1f97 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -59,7 +59,6 @@ class AuthHandler(BaseHandler):
         }
         self.bcrypt_rounds = hs.config.bcrypt_rounds
         self.sessions = {}
-        self.INVALID_TOKEN_HTTP_STATUS = 401
 
         self.ldap_enabled = hs.config.ldap_enabled
         if self.ldap_enabled:
@@ -149,13 +148,19 @@ class AuthHandler(BaseHandler):
         creds = session['creds']
 
         # check auth type currently being presented
+        errordict = {}
         if 'type' in authdict:
             if authdict['type'] not in self.checkers:
                 raise LoginError(400, "", Codes.UNRECOGNIZED)
-            result = yield self.checkers[authdict['type']](authdict, clientip)
-            if result:
-                creds[authdict['type']] = result
-                self._save_session(session)
+            try:
+                result = yield self.checkers[authdict['type']](authdict, clientip)
+                if result:
+                    creds[authdict['type']] = result
+                    self._save_session(session)
+            except LoginError, e:
+                # this step failed. Merge the error dict into the response
+                # so that the client can have another go.
+                errordict = e.error_dict()
 
         for f in flows:
             if len(set(f) - set(creds.keys())) == 0:
@@ -164,6 +169,7 @@ class AuthHandler(BaseHandler):
 
         ret = self._auth_dict_for_flows(flows, session)
         ret['completed'] = creds.keys()
+        ret.update(errordict)
         defer.returnValue((False, ret, clientdict, session['id']))
 
     @defer.inlineCallbacks
@@ -431,37 +437,40 @@ class AuthHandler(BaseHandler):
             defer.Deferred: (str) canonical_user_id, or None if zero or
             multiple matches
         """
-        try:
-            res = yield self._find_user_id_and_pwd_hash(user_id)
+        res = yield self._find_user_id_and_pwd_hash(user_id)
+        if res is not None:
             defer.returnValue(res[0])
-        except LoginError:
-            defer.returnValue(None)
+        defer.returnValue(None)
 
     @defer.inlineCallbacks
     def _find_user_id_and_pwd_hash(self, user_id):
         """Checks to see if a user with the given id exists. Will check case
-        insensitively, but will throw if there are multiple inexact matches.
+        insensitively, but will return None if there are multiple inexact
+        matches.
 
         Returns:
             tuple: A 2-tuple of `(canonical_user_id, password_hash)`
+            None: if there is not exactly one match
         """
         user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
+
+        result = None
         if not user_infos:
             logger.warn("Attempted to login as %s but they do not exist", user_id)
-            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
-        if len(user_infos) > 1:
-            if user_id not in user_infos:
-                logger.warn(
-                    "Attempted to login as %s but it matches more than one user "
-                    "inexactly: %r",
-                    user_id, user_infos.keys()
-                )
-                raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
-            defer.returnValue((user_id, user_infos[user_id]))
+        elif len(user_infos) == 1:
+            # a single match (possibly not exact)
+            result = user_infos.popitem()
+        elif user_id in user_infos:
+            # multiple matches, but one is exact
+            result = (user_id, user_infos[user_id])
         else:
-            defer.returnValue(user_infos.popitem())
+            # multiple matches, none of them exact
+            logger.warn(
+                "Attempted to login as %s but it matches more than one user "
+                "inexactly: %r",
+                user_id, user_infos.keys()
+            )
+        defer.returnValue(result)
 
     @defer.inlineCallbacks
     def _check_password(self, user_id, password):
@@ -475,34 +484,45 @@ class AuthHandler(BaseHandler):
         Returns:
             (str) the canonical_user_id
         Raises:
-            LoginError if the password was incorrect
+            LoginError if login fails
         """
         valid_ldap = yield self._check_ldap_password(user_id, password)
         if valid_ldap:
             defer.returnValue(user_id)
 
-        result = yield self._check_local_password(user_id, password)
-        defer.returnValue(result)
+        canonical_user_id = yield self._check_local_password(user_id, password)
+
+        if canonical_user_id:
+            defer.returnValue(canonical_user_id)
+
+        # unknown username or invalid password. We raise a 403 here, but note
+        # that if we're doing user-interactive login, it turns all LoginErrors
+        # into a 401 anyway.
+        raise LoginError(
+            403, "Invalid password",
+            errcode=Codes.FORBIDDEN
+        )
 
     @defer.inlineCallbacks
     def _check_local_password(self, user_id, password):
         """Authenticate a user against the local password database.
 
-        user_id is checked case insensitively, but will throw if there are
+        user_id is checked case insensitively, but will return None if there are
         multiple inexact matches.
 
         Args:
             user_id (str): complete @user:id
         Returns:
-            (str) the canonical_user_id
-        Raises:
-            LoginError if the password was incorrect
+            (str) the canonical_user_id, or None if unknown user / bad password
         """
-        user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
+        lookupres = yield self._find_user_id_and_pwd_hash(user_id)
+        if not lookupres:
+            defer.returnValue(None)
+        (user_id, password_hash) = lookupres
         result = self.validate_hash(password, password_hash)
         if not result:
             logger.warn("Failed password login for user %s", user_id)
-            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+            defer.returnValue(None)
         defer.returnValue(user_id)
 
     def _ldap_simple_bind(self, server, localpart, password):