diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 5a0ed9d6b9..983994fa95 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -230,7 +230,6 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
- @defer.inlineCallbacks
def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
@@ -240,11 +239,7 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string()
- if not (yield self._check_password(user_id, password)):
- logger.warn("Failed password login for user %s", user_id)
- raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
- defer.returnValue(user_id)
+ return self._check_password(user_id, password)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
@@ -348,67 +343,66 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
- @defer.inlineCallbacks
- def login_with_password(self, user_id, password):
+ def validate_password_login(self, user_id, password):
"""
Authenticates the user with their username and password.
Used only by the v1 login API.
Args:
- user_id (str): User ID
+ user_id (str): complete @user:id
password (str): Password
Returns:
- A tuple of:
- The user's ID.
- The access token for the user's session.
- The refresh token for the user's session.
+ defer.Deferred: (str) canonical user id
Raises:
- StoreError if there was a problem storing the token.
+ StoreError if there was a problem accessing the database
LoginError if there was an authentication problem.
"""
-
- if not (yield self._check_password(user_id, password)):
- logger.warn("Failed password login for user %s", user_id)
- raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
- logger.info("Logging in user %s", user_id)
- access_token = yield self.issue_access_token(user_id)
- refresh_token = yield self.issue_refresh_token(user_id)
- defer.returnValue((user_id, access_token, refresh_token))
+ return self._check_password(user_id, password)
@defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id):
"""
Gets login tuple for the user with the given user ID.
+
+ Creates a new access/refresh token for the user.
+
The user is assumed to have been authenticated by some other
- machanism (e.g. CAS)
+ machanism (e.g. CAS), and the user_id converted to the canonical case.
Args:
- user_id (str): User ID
+ user_id (str): canonical User ID
Returns:
A tuple of:
- The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
- user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
-
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
- defer.returnValue((user_id, access_token, refresh_token))
+ defer.returnValue((access_token, refresh_token))
@defer.inlineCallbacks
- def does_user_exist(self, user_id):
+ def check_user_exists(self, user_id):
+ """
+ Checks to see if a user with the given id exists. Will check case
+ insensitively, but return None if there are multiple inexact matches.
+
+ Args:
+ (str) user_id: complete @user:id
+
+ Returns:
+ defer.Deferred: (str) canonical_user_id, or None if zero or
+ multiple matches
+ """
try:
- yield self._find_user_id_and_pwd_hash(user_id)
- defer.returnValue(True)
+ res = yield self._find_user_id_and_pwd_hash(user_id)
+ defer.returnValue(res[0])
except LoginError:
- defer.returnValue(False)
+ defer.returnValue(None)
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
@@ -438,27 +432,45 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def _check_password(self, user_id, password):
- """
+ """Authenticate a user against the LDAP and local databases.
+
+ user_id is checked case insensitively against the local database, but
+ will throw if there are multiple inexact matches.
+
+ Args:
+ user_id (str): complete @user:id
Returns:
- True if the user_id successfully authenticated
+ (str) the canonical_user_id
+ Raises:
+ LoginError if the password was incorrect
"""
valid_ldap = yield self._check_ldap_password(user_id, password)
if valid_ldap:
- defer.returnValue(True)
+ defer.returnValue(user_id)
- valid_local_password = yield self._check_local_password(user_id, password)
- if valid_local_password:
- defer.returnValue(True)
-
- defer.returnValue(False)
+ result = yield self._check_local_password(user_id, password)
+ defer.returnValue(result)
@defer.inlineCallbacks
def _check_local_password(self, user_id, password):
- try:
- user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
- defer.returnValue(self.validate_hash(password, password_hash))
- except LoginError:
- defer.returnValue(False)
+ """Authenticate a user against the local password database.
+
+ user_id is checked case insensitively, but will throw if there are
+ multiple inexact matches.
+
+ Args:
+ user_id (str): complete @user:id
+ Returns:
+ (str) the canonical_user_id
+ Raises:
+ LoginError if the password was incorrect
+ """
+ user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
+ result = self.validate_hash(password, password_hash)
+ if not result:
+ logger.warn("Failed password login for user %s", user_id)
+ raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+ defer.returnValue(user_id)
@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
@@ -570,7 +582,7 @@ class AuthHandler(BaseHandler):
)
# check for existing account, if none exists, create one
- if not (yield self.does_user_exist(user_id)):
+ if not (yield self.check_user_exists(user_id)):
# query user metadata for account creation
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 8df9d10efa..a1f2ba8773 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -145,10 +145,13 @@ class LoginRestServlet(ClientV1RestServlet):
).to_string()
auth_handler = self.auth_handler
- user_id, access_token, refresh_token = yield auth_handler.login_with_password(
+ user_id = yield auth_handler.validate_password_login(
user_id=user_id,
- password=login_submission["password"])
-
+ password=login_submission["password"],
+ )
+ access_token, refresh_token = (
+ yield auth_handler.get_login_tuple_for_user_id(user_id)
+ )
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
@@ -165,7 +168,7 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
- user_id, access_token, refresh_token = (
+ access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
)
result = {
@@ -196,13 +199,15 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
- user_exists = yield auth_handler.does_user_exist(user_id)
- if user_exists:
- user_id, access_token, refresh_token = (
- yield auth_handler.get_login_tuple_for_user_id(user_id)
+ registered_user_id = yield auth_handler.check_user_exists(user_id)
+ if registered_user_id:
+ access_token, refresh_token = (
+ yield auth_handler.get_login_tuple_for_user_id(
+ registered_user_id
+ )
)
result = {
- "user_id": user_id, # may have changed
+ "user_id": registered_user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
@@ -245,13 +250,13 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
- user_exists = yield auth_handler.does_user_exist(user_id)
- if user_exists:
- user_id, access_token, refresh_token = (
- yield auth_handler.get_login_tuple_for_user_id(user_id)
+ registered_user_id = yield auth_handler.check_user_exists(user_id)
+ if registered_user_id:
+ access_token, refresh_token = (
+ yield auth_handler.get_login_tuple_for_user_id(registered_user_id)
)
result = {
- "user_id": user_id, # may have changed
+ "user_id": registered_user_id,
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
@@ -414,13 +419,13 @@ class CasTicketServlet(ClientV1RestServlet):
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
- user_exists = yield auth_handler.does_user_exist(user_id)
- if not user_exists:
- user_id, _ = (
+ registered_user_id = yield auth_handler.check_user_exists(user_id)
+ if not registered_user_id:
+ registered_user_id, _ = (
yield self.handlers.registration_handler.register(localpart=user)
)
- login_token = auth_handler.generate_short_term_login_token(user_id)
+ login_token = auth_handler.generate_short_term_login_token(registered_user_id)
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token)
request.redirect(redirect_url)
|