diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 95b0cfeb48..573c9db8a1 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -49,7 +49,6 @@ class AuthHandler(BaseHandler):
"""
super(AuthHandler, self).__init__(hs)
self.checkers = {
- LoginType.PASSWORD: self._check_password_auth,
LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.MSISDN: self._check_msisdn,
@@ -78,15 +77,20 @@ class AuthHandler(BaseHandler):
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
- login_types = set()
+ # we keep this as a list despite the O(N^2) implication so that we can
+ # keep PASSWORD first and avoid confusing clients which pick the first
+ # type in the list. (NB that the spec doesn't require us to do so and
+ # clients which favour types that they don't understand over those that
+ # they do are technically broken)
+ login_types = []
if self._password_enabled:
- login_types.add(LoginType.PASSWORD)
+ login_types.append(LoginType.PASSWORD)
for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"):
- login_types.update(
- provider.get_supported_login_types().keys()
- )
- self._supported_login_types = frozenset(login_types)
+ for t in provider.get_supported_login_types().keys():
+ if t not in login_types:
+ login_types.append(t)
+ self._supported_login_types = login_types
@defer.inlineCallbacks
def validate_user_via_ui_auth(self, requester, request_body, clientip):
@@ -116,14 +120,27 @@ class AuthHandler(BaseHandler):
a different user to `requester`
"""
- # we only support password login here
- flows = [[LoginType.PASSWORD]]
+ # build a list of supported flows
+ flows = [
+ [login_type] for login_type in self._supported_login_types
+ ]
result, params, _ = yield self.check_auth(
flows, request_body, clientip,
)
- user_id = result[LoginType.PASSWORD]
+ # find the completed login type
+ for login_type in self._supported_login_types:
+ if login_type not in result:
+ continue
+
+ user_id = result[login_type]
+ break
+ else:
+ # this can't happen
+ raise Exception(
+ "check_auth returned True but no successful login type",
+ )
# check that the UI auth matched the access token
if user_id != requester.user.to_string():
@@ -210,14 +227,12 @@ class AuthHandler(BaseHandler):
errordict = {}
if 'type' in authdict:
login_type = authdict['type']
- if login_type not in self.checkers:
- raise LoginError(400, "", Codes.UNRECOGNIZED)
try:
- result = yield self.checkers[login_type](authdict, clientip)
+ result = yield self._check_auth_dict(authdict, clientip)
if result:
creds[login_type] = result
self._save_session(session)
- except LoginError, e:
+ except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
# validation token (thus sending a new email) each time it
@@ -226,7 +241,7 @@ class AuthHandler(BaseHandler):
#
# Grandfather in the old behaviour for now to avoid
# breaking old riot deployments.
- raise e
+ raise
# this step failed. Merge the error dict into the response
# so that the client can have another go.
@@ -323,17 +338,35 @@ class AuthHandler(BaseHandler):
return sess.setdefault('serverdict', {}).get(key, default)
@defer.inlineCallbacks
- def _check_password_auth(self, authdict, _):
- if "user" not in authdict or "password" not in authdict:
- raise LoginError(400, "", Codes.MISSING_PARAM)
+ def _check_auth_dict(self, authdict, clientip):
+ """Attempt to validate the auth dict provided by a client
+
+ Args:
+ authdict (object): auth dict provided by the client
+ clientip (str): IP address of the client
+
+ Returns:
+ Deferred: result of the stage verification.
+
+ Raises:
+ StoreError if there was a problem accessing the database
+ SynapseError if there was a problem with the request
+ LoginError if there was an authentication problem.
+ """
+ login_type = authdict['type']
+ checker = self.checkers.get(login_type)
+ if checker is not None:
+ res = yield checker(authdict, clientip)
+ defer.returnValue(res)
+
+ # build a v1-login-style dict out of the authdict and fall back to the
+ # v1 code
+ user_id = authdict.get("user")
- user_id = authdict["user"]
- password = authdict["password"]
+ if user_id is None:
+ raise SynapseError(400, "", Codes.MISSING_PARAM)
- (canonical_id, callback) = yield self.validate_login(user_id, {
- "type": LoginType.PASSWORD,
- "password": password,
- })
+ (canonical_id, callback) = yield self.validate_login(user_id, authdict)
defer.returnValue(canonical_id)
@defer.inlineCallbacks
|