diff options
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r-- | synapse/handlers/auth.py | 266 |
1 files changed, 121 insertions, 145 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a0cf37a9f9..97b21c4093 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -134,13 +134,9 @@ class AuthHandler(BaseHandler): """ # build a list of supported flows - flows = [ - [login_type] for login_type in self._supported_login_types - ] + flows = [[login_type] for login_type in self._supported_login_types] - result, params, _ = yield self.check_auth( - flows, request_body, clientip, - ) + result, params, _ = yield self.check_auth(flows, request_body, clientip) # find the completed login type for login_type in self._supported_login_types: @@ -151,9 +147,7 @@ class AuthHandler(BaseHandler): break else: # this can't happen - raise Exception( - "check_auth returned True but no successful login type", - ) + raise Exception("check_auth returned True but no successful login type") # check that the UI auth matched the access token if user_id != requester.user.to_string(): @@ -215,11 +209,11 @@ class AuthHandler(BaseHandler): authdict = None sid = None - if clientdict and 'auth' in clientdict: - authdict = clientdict['auth'] - del clientdict['auth'] - if 'session' in authdict: - sid = authdict['session'] + if clientdict and "auth" in clientdict: + authdict = clientdict["auth"] + del clientdict["auth"] + if "session" in authdict: + sid = authdict["session"] session = self._get_session_info(sid) if len(clientdict) > 0: @@ -232,27 +226,27 @@ class AuthHandler(BaseHandler): # on a home server. # Revisit: Assumimg the REST APIs do sensible validation, the data # isn't arbintrary. - session['clientdict'] = clientdict + session["clientdict"] = clientdict self._save_session(session) - elif 'clientdict' in session: - clientdict = session['clientdict'] + elif "clientdict" in session: + clientdict = session["clientdict"] if not authdict: raise InteractiveAuthIncompleteError( - self._auth_dict_for_flows(flows, session), + self._auth_dict_for_flows(flows, session) ) - if 'creds' not in session: - session['creds'] = {} - creds = session['creds'] + if "creds" not in session: + session["creds"] = {} + creds = session["creds"] # check auth type currently being presented errordict = {} - if 'type' in authdict: - login_type = authdict['type'] + if "type" in authdict: + login_type = authdict["type"] try: result = yield self._check_auth_dict( - authdict, clientip, password_servlet=password_servlet, + authdict, clientip, password_servlet=password_servlet ) if result: creds[login_type] = result @@ -281,16 +275,15 @@ class AuthHandler(BaseHandler): # and is not sensitive). logger.info( "Auth completed with creds: %r. Client dict has keys: %r", - creds, list(clientdict) + creds, + list(clientdict), ) - defer.returnValue((creds, clientdict, session['id'])) + defer.returnValue((creds, clientdict, session["id"])) ret = self._auth_dict_for_flows(flows, session) - ret['completed'] = list(creds) + ret["completed"] = list(creds) ret.update(errordict) - raise InteractiveAuthIncompleteError( - ret, - ) + raise InteractiveAuthIncompleteError(ret) @defer.inlineCallbacks def add_oob_auth(self, stagetype, authdict, clientip): @@ -300,15 +293,13 @@ class AuthHandler(BaseHandler): """ if stagetype not in self.checkers: raise LoginError(400, "", Codes.MISSING_PARAM) - if 'session' not in authdict: + if "session" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) - sess = self._get_session_info( - authdict['session'] - ) - if 'creds' not in sess: - sess['creds'] = {} - creds = sess['creds'] + sess = self._get_session_info(authdict["session"]) + if "creds" not in sess: + sess["creds"] = {} + creds = sess["creds"] result = yield self.checkers[stagetype](authdict, clientip) if result: @@ -329,10 +320,10 @@ class AuthHandler(BaseHandler): not send a session ID, returns None. """ sid = None - if clientdict and 'auth' in clientdict: - authdict = clientdict['auth'] - if 'session' in authdict: - sid = authdict['session'] + if clientdict and "auth" in clientdict: + authdict = clientdict["auth"] + if "session" in authdict: + sid = authdict["session"] return sid def set_session_data(self, session_id, key, value): @@ -347,7 +338,7 @@ class AuthHandler(BaseHandler): value (any): The data to store """ sess = self._get_session_info(session_id) - sess.setdefault('serverdict', {})[key] = value + sess.setdefault("serverdict", {})[key] = value self._save_session(sess) def get_session_data(self, session_id, key, default=None): @@ -360,7 +351,7 @@ class AuthHandler(BaseHandler): default (any): Value to return if the key has not been set """ sess = self._get_session_info(session_id) - return sess.setdefault('serverdict', {}).get(key, default) + return sess.setdefault("serverdict", {}).get(key, default) @defer.inlineCallbacks def _check_auth_dict(self, authdict, clientip, password_servlet=False): @@ -378,15 +369,13 @@ class AuthHandler(BaseHandler): SynapseError if there was a problem with the request LoginError if there was an authentication problem. """ - login_type = authdict['type'] + login_type = authdict["type"] checker = self.checkers.get(login_type) if checker is not None: # XXX: Temporary workaround for having Synapse handle password resets # See AuthHandler.check_auth for further details res = yield checker( - authdict, - clientip=clientip, - password_servlet=password_servlet, + authdict, clientip=clientip, password_servlet=password_servlet ) defer.returnValue(res) @@ -408,13 +397,11 @@ class AuthHandler(BaseHandler): # Client tried to provide captcha but didn't give the parameter: # bad request. raise LoginError( - 400, "Captcha response is required", - errcode=Codes.CAPTCHA_NEEDED + 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED ) logger.info( - "Submitting recaptcha response %s with remoteip %s", - user_response, clientip + "Submitting recaptcha response %s with remoteip %s", user_response, clientip ) # TODO: get this from the homeserver rather than creating a new one for @@ -424,34 +411,34 @@ class AuthHandler(BaseHandler): resp_body = yield client.post_urlencoded_get_json( self.hs.config.recaptcha_siteverify_api, args={ - 'secret': self.hs.config.recaptcha_private_key, - 'response': user_response, - 'remoteip': clientip, - } + "secret": self.hs.config.recaptcha_private_key, + "response": user_response, + "remoteip": clientip, + }, ) except PartialDownloadError as pde: # Twisted is silly data = pde.response resp_body = json.loads(data) - if 'success' in resp_body: + if "success" in resp_body: # Note that we do NOT check the hostname here: we explicitly # intend the CAPTCHA to be presented by whatever client the # user is using, we just care that they have completed a CAPTCHA. logger.info( "%s reCAPTCHA from hostname %s", - "Successful" if resp_body['success'] else "Failed", - resp_body.get('hostname') + "Successful" if resp_body["success"] else "Failed", + resp_body.get("hostname"), ) - if resp_body['success']: + if resp_body["success"]: defer.returnValue(True) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) def _check_email_identity(self, authdict, **kwargs): - return self._check_threepid('email', authdict, **kwargs) + return self._check_threepid("email", authdict, **kwargs) def _check_msisdn(self, authdict, **kwargs): - return self._check_threepid('msisdn', authdict) + return self._check_threepid("msisdn", authdict) def _check_dummy_auth(self, authdict, **kwargs): return defer.succeed(True) @@ -461,10 +448,10 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs): - if 'threepid_creds' not in authdict: + if "threepid_creds" not in authdict: raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) - threepid_creds = authdict['threepid_creds'] + threepid_creds = authdict["threepid_creds"] identity_handler = self.hs.get_handlers().identity_handler @@ -482,31 +469,36 @@ class AuthHandler(BaseHandler): validated=True, ) - threepid = { - "medium": row["medium"], - "address": row["address"], - "validated_at": row["validated_at"], - } if row else None + threepid = ( + { + "medium": row["medium"], + "address": row["address"], + "validated_at": row["validated_at"], + } + if row + else None + ) if row: # Valid threepid returned, delete from the db yield self.store.delete_threepid_session(threepid_creds["sid"]) else: - raise SynapseError(400, "Password resets are not enabled on this homeserver") + raise SynapseError( + 400, "Password resets are not enabled on this homeserver" + ) if not threepid: raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) - if threepid['medium'] != medium: + if threepid["medium"] != medium: raise LoginError( 401, - "Expecting threepid of type '%s', got '%s'" % ( - medium, threepid['medium'], - ), - errcode=Codes.UNAUTHORIZED + "Expecting threepid of type '%s', got '%s'" + % (medium, threepid["medium"]), + errcode=Codes.UNAUTHORIZED, ) - threepid['threepid_creds'] = authdict['threepid_creds'] + threepid["threepid_creds"] = authdict["threepid_creds"] defer.returnValue(threepid) @@ -520,13 +512,14 @@ class AuthHandler(BaseHandler): "version": self.hs.config.user_consent_version, "en": { "name": self.hs.config.user_consent_policy_name, - "url": "%s_matrix/consent?v=%s" % ( + "url": "%s_matrix/consent?v=%s" + % ( self.hs.config.public_baseurl, self.hs.config.user_consent_version, ), }, - }, - }, + } + } } def _auth_dict_for_flows(self, flows, session): @@ -547,9 +540,9 @@ class AuthHandler(BaseHandler): params[stage] = get_params[stage]() return { - "session": session['id'], + "session": session["id"], "flows": [{"stages": f} for f in public_flows], - "params": params + "params": params, } def _get_session_info(self, session_id): @@ -560,9 +553,7 @@ class AuthHandler(BaseHandler): # create a new session while session_id is None or session_id in self.sessions: session_id = stringutils.random_string(24) - self.sessions[session_id] = { - "id": session_id, - } + self.sessions[session_id] = {"id": session_id} return self.sessions[session_id] @@ -652,7 +643,8 @@ class AuthHandler(BaseHandler): logger.warn( "Attempted to login as %s but it matches more than one user " "inexactly: %r", - user_id, user_infos.keys() + user_id, + user_infos.keys(), ) defer.returnValue(result) @@ -690,12 +682,10 @@ class AuthHandler(BaseHandler): user is too high too proceed. """ - if username.startswith('@'): + if username.startswith("@"): qualified_user_id = username else: - qualified_user_id = UserID( - username, self.hs.hostname - ).to_string() + qualified_user_id = UserID(username, self.hs.hostname).to_string() self.ratelimit_login_per_account(qualified_user_id) @@ -713,17 +703,15 @@ class AuthHandler(BaseHandler): raise SynapseError(400, "Missing parameter: password") for provider in self.password_providers: - if (hasattr(provider, "check_password") - and login_type == LoginType.PASSWORD): + if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: known_login_type = True - is_valid = yield provider.check_password( - qualified_user_id, password, - ) + is_valid = yield provider.check_password(qualified_user_id, password) if is_valid: defer.returnValue((qualified_user_id, None)) - if (not hasattr(provider, "get_supported_login_types") - or not hasattr(provider, "check_auth")): + if not hasattr(provider, "get_supported_login_types") or not hasattr( + provider, "check_auth" + ): # this password provider doesn't understand custom login types continue @@ -744,15 +732,12 @@ class AuthHandler(BaseHandler): login_dict[f] = login_submission[f] if missing_fields: raise SynapseError( - 400, "Missing parameters for login type %s: %s" % ( - login_type, - missing_fields, - ), + 400, + "Missing parameters for login type %s: %s" + % (login_type, missing_fields), ) - result = yield provider.check_auth( - username, login_type, login_dict, - ) + result = yield provider.check_auth(username, login_type, login_dict) if result: if isinstance(result, str): result = (result, None) @@ -762,7 +747,7 @@ class AuthHandler(BaseHandler): known_login_type = True canonical_user_id = yield self._check_local_password( - qualified_user_id, password, + qualified_user_id, password ) if canonical_user_id: @@ -773,7 +758,8 @@ class AuthHandler(BaseHandler): # unknown username or invalid password. self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), time_now_s=self._clock.time(), + qualified_user_id.lower(), + time_now_s=self._clock.time(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=True, @@ -781,10 +767,7 @@ class AuthHandler(BaseHandler): # We raise a 403 here, but note that if we're doing user-interactive # login, it turns all LoginErrors into a 401 anyway. - raise LoginError( - 403, "Invalid password", - errcode=Codes.FORBIDDEN - ) + raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) @defer.inlineCallbacks def check_password_provider_3pid(self, medium, address, password): @@ -810,9 +793,7 @@ class AuthHandler(BaseHandler): # success, to a str (which is the user_id) or a tuple of # (user_id, callback_func), where callback_func should be run # after we've finished everything else - result = yield provider.check_3pid_auth( - medium, address, password, - ) + result = yield provider.check_3pid_auth(medium, address, password) if result: # Check if the return value is a str or a tuple if isinstance(result, str): @@ -853,8 +834,7 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def issue_access_token(self, user_id, device_id=None): access_token = self.macaroon_gen.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token, - device_id) + yield self.store.add_access_token_to_user(user_id, access_token, device_id) defer.returnValue(access_token) @defer.inlineCallbacks @@ -896,12 +876,13 @@ class AuthHandler(BaseHandler): # delete pushers associated with this access token if user_info["token_id"] is not None: yield self.hs.get_pusherpool().remove_pushers_by_access_token( - str(user_info["user"]), (user_info["token_id"], ) + str(user_info["user"]), (user_info["token_id"],) ) @defer.inlineCallbacks - def delete_access_tokens_for_user(self, user_id, except_token_id=None, - device_id=None): + def delete_access_tokens_for_user( + self, user_id, except_token_id=None, device_id=None + ): """Invalidate access tokens belonging to a user Args: @@ -915,7 +896,7 @@ class AuthHandler(BaseHandler): Deferred """ tokens_and_devices = yield self.store.user_delete_access_tokens( - user_id, except_token_id=except_token_id, device_id=device_id, + user_id, except_token_id=except_token_id, device_id=device_id ) # see if any of our auth providers want to know about this @@ -923,14 +904,12 @@ class AuthHandler(BaseHandler): if hasattr(provider, "on_logged_out"): for token, token_id, device_id in tokens_and_devices: yield provider.on_logged_out( - user_id=user_id, - device_id=device_id, - access_token=token, + user_id=user_id, device_id=device_id, access_token=token ) # delete pushers associated with the access tokens yield self.hs.get_pusherpool().remove_pushers_by_access_token( - user_id, (token_id for _, token_id, _ in tokens_and_devices), + user_id, (token_id for _, token_id, _ in tokens_and_devices) ) @defer.inlineCallbacks @@ -944,12 +923,11 @@ class AuthHandler(BaseHandler): # of specific types of threepid (and fixes the fact that checking # for the presence of an email address during password reset was # case sensitive). - if medium == 'email': + if medium == "email": address = address.lower() yield self.store.user_add_threepid( - user_id, medium, address, validated_at, - self.hs.get_clock().time_msec() + user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) @defer.inlineCallbacks @@ -973,22 +951,15 @@ class AuthHandler(BaseHandler): """ # 'Canonicalise' email addresses as per above - if medium == 'email': + if medium == "email": address = address.lower() identity_handler = self.hs.get_handlers().identity_handler result = yield identity_handler.try_unbind_threepid( - user_id, - { - 'medium': medium, - 'address': address, - 'id_server': id_server, - }, + user_id, {"medium": medium, "address": address, "id_server": id_server} ) - yield self.store.user_delete_threepid( - user_id, medium, address, - ) + yield self.store.user_delete_threepid(user_id, medium, address) defer.returnValue(result) def _save_session(self, session): @@ -1006,14 +977,15 @@ class AuthHandler(BaseHandler): Returns: Deferred(unicode): Hashed password. """ + def _do_hash(): # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) return bcrypt.hashpw( - pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), + pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), bcrypt.gensalt(self.bcrypt_rounds), - ).decode('ascii') + ).decode("ascii") return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash) @@ -1027,18 +999,19 @@ class AuthHandler(BaseHandler): Returns: Deferred(bool): Whether self.hash(password) == stored_hash. """ + def _do_validate_hash(): # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) return bcrypt.checkpw( - pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), - stored_hash + pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), + stored_hash, ) if stored_hash: if not isinstance(stored_hash, bytes): - stored_hash = stored_hash.encode('ascii') + stored_hash = stored_hash.encode("ascii") return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash) else: @@ -1058,14 +1031,16 @@ class AuthHandler(BaseHandler): for this user is too high too proceed. """ self._failed_attempts_ratelimiter.ratelimit( - user_id.lower(), time_now_s=self._clock.time(), + user_id.lower(), + time_now_s=self._clock.time(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=False, ) self._account_ratelimiter.ratelimit( - user_id.lower(), time_now_s=self._clock.time(), + user_id.lower(), + time_now_s=self._clock.time(), rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, update=True, @@ -1083,9 +1058,9 @@ class MacaroonGenerator(object): macaroon.add_first_party_caveat("type = access") # Include a nonce, to make sure that each login gets a different # access token. - macaroon.add_first_party_caveat("nonce = %s" % ( - stringutils.random_string_with_symbols(16), - )) + macaroon.add_first_party_caveat( + "nonce = %s" % (stringutils.random_string_with_symbols(16),) + ) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) return macaroon.serialize() @@ -1116,7 +1091,8 @@ class MacaroonGenerator(object): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key) + key=self.hs.config.macaroon_secret_key, + ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon |