diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index ff2c66f442..058a0f416d 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -162,7 +162,8 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string()
- yield self._check_password(user_id, password)
+ user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
+ self._check_password(user_id, password, password_hash)
defer.returnValue(user_id)
@defer.inlineCallbacks
@@ -283,23 +284,37 @@ class AuthHandler(BaseHandler):
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
- yield self._check_password(user_id, password)
+ user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
+ self._check_password(user_id, password, password_hash)
reg_handler = self.hs.get_handlers().registration_handler
access_token = reg_handler.generate_token(user_id)
logger.info("Logging in user %s", user_id)
yield self.store.add_access_token_to_user(user_id, access_token)
- defer.returnValue(access_token)
+ defer.returnValue((user_id, access_token))
@defer.inlineCallbacks
- def _check_password(self, user_id, password):
- """Checks that user_id has passed password, raises LoginError if not."""
- user_info = yield self.store.get_user_by_id(user_id=user_id)
- if not user_info:
+ def _find_user_id_and_pwd_hash(self, user_id):
+ user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
+ if not user_infos:
logger.warn("Attempted to login as %s but they do not exist", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
- stored_hash = user_info["password_hash"]
+ if len(user_infos) > 1:
+ if user_id not in user_infos:
+ logger.warn(
+ "Attempted to login as %s but it matches more than one user "
+ "inexactly: %r",
+ user_id, user_infos.keys()
+ )
+ raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+ defer.returnValue((user_id, user_infos[user_id]))
+ else:
+ defer.returnValue(user_infos.popitem())
+
+ def _check_password(self, user_id, password, stored_hash):
+ """Checks that user_id has passed password, raises LoginError if not."""
if not bcrypt.checkpw(password, stored_hash):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 0d5eafd0fa..2444f27366 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -83,9 +83,10 @@ class LoginRestServlet(ClientV1RestServlet):
if not user_id.startswith('@'):
user_id = UserID.create(
- user_id, self.hs.hostname).to_string()
+ user_id, self.hs.hostname
+ ).to_string()
- token = yield self.handlers.auth_handler.login_with_password(
+ user_id, token = yield self.handlers.auth_handler.login_with_password(
user_id=user_id,
password=login_submission["password"])
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 25adecaf6d..586628579d 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -99,13 +99,16 @@ class RegistrationStore(SQLBaseStore):
)
def get_users_by_id_case_insensitive(self, user_id):
+ """Gets users that match user_id case insensitively.
+ Returns a mapping of user_id -> password_hash.
+ """
def f(txn):
sql = (
"SELECT name, password_hash FROM users"
- " WHERE name = lower(?)"
+ " WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
- return self.cursor_to_dict(txn)
+ return dict(txn.fetchall())
return self.runInteraction("get_users_by_id_case_insensitive", f)
|