diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a0cf37a9f9..97b21c4093 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -134,13 +134,9 @@ class AuthHandler(BaseHandler):
"""
# build a list of supported flows
- flows = [
- [login_type] for login_type in self._supported_login_types
- ]
+ flows = [[login_type] for login_type in self._supported_login_types]
- result, params, _ = yield self.check_auth(
- flows, request_body, clientip,
- )
+ result, params, _ = yield self.check_auth(flows, request_body, clientip)
# find the completed login type
for login_type in self._supported_login_types:
@@ -151,9 +147,7 @@ class AuthHandler(BaseHandler):
break
else:
# this can't happen
- raise Exception(
- "check_auth returned True but no successful login type",
- )
+ raise Exception("check_auth returned True but no successful login type")
# check that the UI auth matched the access token
if user_id != requester.user.to_string():
@@ -215,11 +209,11 @@ class AuthHandler(BaseHandler):
authdict = None
sid = None
- if clientdict and 'auth' in clientdict:
- authdict = clientdict['auth']
- del clientdict['auth']
- if 'session' in authdict:
- sid = authdict['session']
+ if clientdict and "auth" in clientdict:
+ authdict = clientdict["auth"]
+ del clientdict["auth"]
+ if "session" in authdict:
+ sid = authdict["session"]
session = self._get_session_info(sid)
if len(clientdict) > 0:
@@ -232,27 +226,27 @@ class AuthHandler(BaseHandler):
# on a home server.
# Revisit: Assumimg the REST APIs do sensible validation, the data
# isn't arbintrary.
- session['clientdict'] = clientdict
+ session["clientdict"] = clientdict
self._save_session(session)
- elif 'clientdict' in session:
- clientdict = session['clientdict']
+ elif "clientdict" in session:
+ clientdict = session["clientdict"]
if not authdict:
raise InteractiveAuthIncompleteError(
- self._auth_dict_for_flows(flows, session),
+ self._auth_dict_for_flows(flows, session)
)
- if 'creds' not in session:
- session['creds'] = {}
- creds = session['creds']
+ if "creds" not in session:
+ session["creds"] = {}
+ creds = session["creds"]
# check auth type currently being presented
errordict = {}
- if 'type' in authdict:
- login_type = authdict['type']
+ if "type" in authdict:
+ login_type = authdict["type"]
try:
result = yield self._check_auth_dict(
- authdict, clientip, password_servlet=password_servlet,
+ authdict, clientip, password_servlet=password_servlet
)
if result:
creds[login_type] = result
@@ -281,16 +275,15 @@ class AuthHandler(BaseHandler):
# and is not sensitive).
logger.info(
"Auth completed with creds: %r. Client dict has keys: %r",
- creds, list(clientdict)
+ creds,
+ list(clientdict),
)
- defer.returnValue((creds, clientdict, session['id']))
+ defer.returnValue((creds, clientdict, session["id"]))
ret = self._auth_dict_for_flows(flows, session)
- ret['completed'] = list(creds)
+ ret["completed"] = list(creds)
ret.update(errordict)
- raise InteractiveAuthIncompleteError(
- ret,
- )
+ raise InteractiveAuthIncompleteError(ret)
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
@@ -300,15 +293,13 @@ class AuthHandler(BaseHandler):
"""
if stagetype not in self.checkers:
raise LoginError(400, "", Codes.MISSING_PARAM)
- if 'session' not in authdict:
+ if "session" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
- sess = self._get_session_info(
- authdict['session']
- )
- if 'creds' not in sess:
- sess['creds'] = {}
- creds = sess['creds']
+ sess = self._get_session_info(authdict["session"])
+ if "creds" not in sess:
+ sess["creds"] = {}
+ creds = sess["creds"]
result = yield self.checkers[stagetype](authdict, clientip)
if result:
@@ -329,10 +320,10 @@ class AuthHandler(BaseHandler):
not send a session ID, returns None.
"""
sid = None
- if clientdict and 'auth' in clientdict:
- authdict = clientdict['auth']
- if 'session' in authdict:
- sid = authdict['session']
+ if clientdict and "auth" in clientdict:
+ authdict = clientdict["auth"]
+ if "session" in authdict:
+ sid = authdict["session"]
return sid
def set_session_data(self, session_id, key, value):
@@ -347,7 +338,7 @@ class AuthHandler(BaseHandler):
value (any): The data to store
"""
sess = self._get_session_info(session_id)
- sess.setdefault('serverdict', {})[key] = value
+ sess.setdefault("serverdict", {})[key] = value
self._save_session(sess)
def get_session_data(self, session_id, key, default=None):
@@ -360,7 +351,7 @@ class AuthHandler(BaseHandler):
default (any): Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
- return sess.setdefault('serverdict', {}).get(key, default)
+ return sess.setdefault("serverdict", {}).get(key, default)
@defer.inlineCallbacks
def _check_auth_dict(self, authdict, clientip, password_servlet=False):
@@ -378,15 +369,13 @@ class AuthHandler(BaseHandler):
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
- login_type = authdict['type']
+ login_type = authdict["type"]
checker = self.checkers.get(login_type)
if checker is not None:
# XXX: Temporary workaround for having Synapse handle password resets
# See AuthHandler.check_auth for further details
res = yield checker(
- authdict,
- clientip=clientip,
- password_servlet=password_servlet,
+ authdict, clientip=clientip, password_servlet=password_servlet
)
defer.returnValue(res)
@@ -408,13 +397,11 @@ class AuthHandler(BaseHandler):
# Client tried to provide captcha but didn't give the parameter:
# bad request.
raise LoginError(
- 400, "Captcha response is required",
- errcode=Codes.CAPTCHA_NEEDED
+ 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
)
logger.info(
- "Submitting recaptcha response %s with remoteip %s",
- user_response, clientip
+ "Submitting recaptcha response %s with remoteip %s", user_response, clientip
)
# TODO: get this from the homeserver rather than creating a new one for
@@ -424,34 +411,34 @@ class AuthHandler(BaseHandler):
resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api,
args={
- 'secret': self.hs.config.recaptcha_private_key,
- 'response': user_response,
- 'remoteip': clientip,
- }
+ "secret": self.hs.config.recaptcha_private_key,
+ "response": user_response,
+ "remoteip": clientip,
+ },
)
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = json.loads(data)
- if 'success' in resp_body:
+ if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly
# intend the CAPTCHA to be presented by whatever client the
# user is using, we just care that they have completed a CAPTCHA.
logger.info(
"%s reCAPTCHA from hostname %s",
- "Successful" if resp_body['success'] else "Failed",
- resp_body.get('hostname')
+ "Successful" if resp_body["success"] else "Failed",
+ resp_body.get("hostname"),
)
- if resp_body['success']:
+ if resp_body["success"]:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
def _check_email_identity(self, authdict, **kwargs):
- return self._check_threepid('email', authdict, **kwargs)
+ return self._check_threepid("email", authdict, **kwargs)
def _check_msisdn(self, authdict, **kwargs):
- return self._check_threepid('msisdn', authdict)
+ return self._check_threepid("msisdn", authdict)
def _check_dummy_auth(self, authdict, **kwargs):
return defer.succeed(True)
@@ -461,10 +448,10 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs):
- if 'threepid_creds' not in authdict:
+ if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
- threepid_creds = authdict['threepid_creds']
+ threepid_creds = authdict["threepid_creds"]
identity_handler = self.hs.get_handlers().identity_handler
@@ -482,31 +469,36 @@ class AuthHandler(BaseHandler):
validated=True,
)
- threepid = {
- "medium": row["medium"],
- "address": row["address"],
- "validated_at": row["validated_at"],
- } if row else None
+ threepid = (
+ {
+ "medium": row["medium"],
+ "address": row["address"],
+ "validated_at": row["validated_at"],
+ }
+ if row
+ else None
+ )
if row:
# Valid threepid returned, delete from the db
yield self.store.delete_threepid_session(threepid_creds["sid"])
else:
- raise SynapseError(400, "Password resets are not enabled on this homeserver")
+ raise SynapseError(
+ 400, "Password resets are not enabled on this homeserver"
+ )
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
- if threepid['medium'] != medium:
+ if threepid["medium"] != medium:
raise LoginError(
401,
- "Expecting threepid of type '%s', got '%s'" % (
- medium, threepid['medium'],
- ),
- errcode=Codes.UNAUTHORIZED
+ "Expecting threepid of type '%s', got '%s'"
+ % (medium, threepid["medium"]),
+ errcode=Codes.UNAUTHORIZED,
)
- threepid['threepid_creds'] = authdict['threepid_creds']
+ threepid["threepid_creds"] = authdict["threepid_creds"]
defer.returnValue(threepid)
@@ -520,13 +512,14 @@ class AuthHandler(BaseHandler):
"version": self.hs.config.user_consent_version,
"en": {
"name": self.hs.config.user_consent_policy_name,
- "url": "%s_matrix/consent?v=%s" % (
+ "url": "%s_matrix/consent?v=%s"
+ % (
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
},
- },
- },
+ }
+ }
}
def _auth_dict_for_flows(self, flows, session):
@@ -547,9 +540,9 @@ class AuthHandler(BaseHandler):
params[stage] = get_params[stage]()
return {
- "session": session['id'],
+ "session": session["id"],
"flows": [{"stages": f} for f in public_flows],
- "params": params
+ "params": params,
}
def _get_session_info(self, session_id):
@@ -560,9 +553,7 @@ class AuthHandler(BaseHandler):
# create a new session
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
- self.sessions[session_id] = {
- "id": session_id,
- }
+ self.sessions[session_id] = {"id": session_id}
return self.sessions[session_id]
@@ -652,7 +643,8 @@ class AuthHandler(BaseHandler):
logger.warn(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
- user_id, user_infos.keys()
+ user_id,
+ user_infos.keys(),
)
defer.returnValue(result)
@@ -690,12 +682,10 @@ class AuthHandler(BaseHandler):
user is too high too proceed.
"""
- if username.startswith('@'):
+ if username.startswith("@"):
qualified_user_id = username
else:
- qualified_user_id = UserID(
- username, self.hs.hostname
- ).to_string()
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
self.ratelimit_login_per_account(qualified_user_id)
@@ -713,17 +703,15 @@ class AuthHandler(BaseHandler):
raise SynapseError(400, "Missing parameter: password")
for provider in self.password_providers:
- if (hasattr(provider, "check_password")
- and login_type == LoginType.PASSWORD):
+ if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
known_login_type = True
- is_valid = yield provider.check_password(
- qualified_user_id, password,
- )
+ is_valid = yield provider.check_password(qualified_user_id, password)
if is_valid:
defer.returnValue((qualified_user_id, None))
- if (not hasattr(provider, "get_supported_login_types")
- or not hasattr(provider, "check_auth")):
+ if not hasattr(provider, "get_supported_login_types") or not hasattr(
+ provider, "check_auth"
+ ):
# this password provider doesn't understand custom login types
continue
@@ -744,15 +732,12 @@ class AuthHandler(BaseHandler):
login_dict[f] = login_submission[f]
if missing_fields:
raise SynapseError(
- 400, "Missing parameters for login type %s: %s" % (
- login_type,
- missing_fields,
- ),
+ 400,
+ "Missing parameters for login type %s: %s"
+ % (login_type, missing_fields),
)
- result = yield provider.check_auth(
- username, login_type, login_dict,
- )
+ result = yield provider.check_auth(username, login_type, login_dict)
if result:
if isinstance(result, str):
result = (result, None)
@@ -762,7 +747,7 @@ class AuthHandler(BaseHandler):
known_login_type = True
canonical_user_id = yield self._check_local_password(
- qualified_user_id, password,
+ qualified_user_id, password
)
if canonical_user_id:
@@ -773,7 +758,8 @@ class AuthHandler(BaseHandler):
# unknown username or invalid password.
self._failed_attempts_ratelimiter.ratelimit(
- qualified_user_id.lower(), time_now_s=self._clock.time(),
+ qualified_user_id.lower(),
+ time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=True,
@@ -781,10 +767,7 @@ class AuthHandler(BaseHandler):
# We raise a 403 here, but note that if we're doing user-interactive
# login, it turns all LoginErrors into a 401 anyway.
- raise LoginError(
- 403, "Invalid password",
- errcode=Codes.FORBIDDEN
- )
+ raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def check_password_provider_3pid(self, medium, address, password):
@@ -810,9 +793,7 @@ class AuthHandler(BaseHandler):
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
- result = yield provider.check_3pid_auth(
- medium, address, password,
- )
+ result = yield provider.check_3pid_auth(medium, address, password)
if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
@@ -853,8 +834,7 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def issue_access_token(self, user_id, device_id=None):
access_token = self.macaroon_gen.generate_access_token(user_id)
- yield self.store.add_access_token_to_user(user_id, access_token,
- device_id)
+ yield self.store.add_access_token_to_user(user_id, access_token, device_id)
defer.returnValue(access_token)
@defer.inlineCallbacks
@@ -896,12 +876,13 @@ class AuthHandler(BaseHandler):
# delete pushers associated with this access token
if user_info["token_id"] is not None:
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
- str(user_info["user"]), (user_info["token_id"], )
+ str(user_info["user"]), (user_info["token_id"],)
)
@defer.inlineCallbacks
- def delete_access_tokens_for_user(self, user_id, except_token_id=None,
- device_id=None):
+ def delete_access_tokens_for_user(
+ self, user_id, except_token_id=None, device_id=None
+ ):
"""Invalidate access tokens belonging to a user
Args:
@@ -915,7 +896,7 @@ class AuthHandler(BaseHandler):
Deferred
"""
tokens_and_devices = yield self.store.user_delete_access_tokens(
- user_id, except_token_id=except_token_id, device_id=device_id,
+ user_id, except_token_id=except_token_id, device_id=device_id
)
# see if any of our auth providers want to know about this
@@ -923,14 +904,12 @@ class AuthHandler(BaseHandler):
if hasattr(provider, "on_logged_out"):
for token, token_id, device_id in tokens_and_devices:
yield provider.on_logged_out(
- user_id=user_id,
- device_id=device_id,
- access_token=token,
+ user_id=user_id, device_id=device_id, access_token=token
)
# delete pushers associated with the access tokens
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
- user_id, (token_id for _, token_id, _ in tokens_and_devices),
+ user_id, (token_id for _, token_id, _ in tokens_and_devices)
)
@defer.inlineCallbacks
@@ -944,12 +923,11 @@ class AuthHandler(BaseHandler):
# of specific types of threepid (and fixes the fact that checking
# for the presence of an email address during password reset was
# case sensitive).
- if medium == 'email':
+ if medium == "email":
address = address.lower()
yield self.store.user_add_threepid(
- user_id, medium, address, validated_at,
- self.hs.get_clock().time_msec()
+ user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
)
@defer.inlineCallbacks
@@ -973,22 +951,15 @@ class AuthHandler(BaseHandler):
"""
# 'Canonicalise' email addresses as per above
- if medium == 'email':
+ if medium == "email":
address = address.lower()
identity_handler = self.hs.get_handlers().identity_handler
result = yield identity_handler.try_unbind_threepid(
- user_id,
- {
- 'medium': medium,
- 'address': address,
- 'id_server': id_server,
- },
+ user_id, {"medium": medium, "address": address, "id_server": id_server}
)
- yield self.store.user_delete_threepid(
- user_id, medium, address,
- )
+ yield self.store.user_delete_threepid(user_id, medium, address)
defer.returnValue(result)
def _save_session(self, session):
@@ -1006,14 +977,15 @@ class AuthHandler(BaseHandler):
Returns:
Deferred(unicode): Hashed password.
"""
+
def _do_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.hashpw(
- pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
+ pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
bcrypt.gensalt(self.bcrypt_rounds),
- ).decode('ascii')
+ ).decode("ascii")
return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash)
@@ -1027,18 +999,19 @@ class AuthHandler(BaseHandler):
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
"""
+
def _do_validate_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw(
- pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
- stored_hash
+ pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
+ stored_hash,
)
if stored_hash:
if not isinstance(stored_hash, bytes):
- stored_hash = stored_hash.encode('ascii')
+ stored_hash = stored_hash.encode("ascii")
return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
else:
@@ -1058,14 +1031,16 @@ class AuthHandler(BaseHandler):
for this user is too high too proceed.
"""
self._failed_attempts_ratelimiter.ratelimit(
- user_id.lower(), time_now_s=self._clock.time(),
+ user_id.lower(),
+ time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=False,
)
self._account_ratelimiter.ratelimit(
- user_id.lower(), time_now_s=self._clock.time(),
+ user_id.lower(),
+ time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
update=True,
@@ -1083,9 +1058,9 @@ class MacaroonGenerator(object):
macaroon.add_first_party_caveat("type = access")
# Include a nonce, to make sure that each login gets a different
# access token.
- macaroon.add_first_party_caveat("nonce = %s" % (
- stringutils.random_string_with_symbols(16),
- ))
+ macaroon.add_first_party_caveat(
+ "nonce = %s" % (stringutils.random_string_with_symbols(16),)
+ )
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
@@ -1116,7 +1091,8 @@ class MacaroonGenerator(object):
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
- key=self.hs.config.macaroon_secret_key)
+ key=self.hs.config.macaroon_secret_key,
+ )
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
|