diff options
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r-- | synapse/handlers/auth.py | 113 |
1 files changed, 81 insertions, 32 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 1ecf7fef17..98d99dd0a8 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -47,17 +47,24 @@ class AuthHandler(BaseHandler): self.sessions = {} @defer.inlineCallbacks - def check_auth(self, flows, clientdict, clientip=None): + def check_auth(self, flows, clientdict, clientip): """ Takes a dictionary sent by the client in the login / registration protocol and handles the login flow. + As a side effect, this function fills in the 'creds' key on the user's + session with a map, which maps each auth-type (str) to the relevant + identity authenticated by that auth-type (mostly str, but for captcha, bool). + Args: - flows: list of list of stages - authdict: The dictionary from the client root level, not the - 'auth' key: this method prompts for auth if none is sent. + flows (list): A list of login flows. Each flow is an ordered list of + strings representing auth-types. At least one full + flow must be completed in order for auth to be successful. + clientdict: The dictionary from the client root level, not the + 'auth' key: this method prompts for auth if none is sent. + clientip (str): The IP address of the client. Returns: - A tuple of authed, dict, dict where authed is true if the client + A tuple of (authed, dict, dict) where authed is true if the client has successfully completed an auth flow. If it is true, the first dict contains the authenticated credentials of each stage. @@ -75,7 +82,7 @@ class AuthHandler(BaseHandler): del clientdict['auth'] if 'session' in authdict: sid = authdict['session'] - sess = self._get_session_info(sid) + session = self._get_session_info(sid) if len(clientdict) > 0: # This was designed to allow the client to omit the parameters @@ -87,20 +94,19 @@ class AuthHandler(BaseHandler): # on a home server. # Revisit: Assumimg the REST APIs do sensible validation, the data # isn't arbintrary. - sess['clientdict'] = clientdict - self._save_session(sess) - pass - elif 'clientdict' in sess: - clientdict = sess['clientdict'] + session['clientdict'] = clientdict + self._save_session(session) + elif 'clientdict' in session: + clientdict = session['clientdict'] if not authdict: defer.returnValue( - (False, self._auth_dict_for_flows(flows, sess), clientdict) + (False, self._auth_dict_for_flows(flows, session), clientdict) ) - if 'creds' not in sess: - sess['creds'] = {} - creds = sess['creds'] + if 'creds' not in session: + session['creds'] = {} + creds = session['creds'] # check auth type currently being presented if 'type' in authdict: @@ -109,15 +115,15 @@ class AuthHandler(BaseHandler): result = yield self.checkers[authdict['type']](authdict, clientip) if result: creds[authdict['type']] = result - self._save_session(sess) + self._save_session(session) for f in flows: if len(set(f) - set(creds.keys())) == 0: logger.info("Auth completed with creds: %r", creds) - self._remove_session(sess) + self._remove_session(session) defer.returnValue((True, creds, clientdict)) - ret = self._auth_dict_for_flows(flows, sess) + ret = self._auth_dict_for_flows(flows, session) ret['completed'] = creds.keys() defer.returnValue((False, ret, clientdict)) @@ -151,22 +157,13 @@ class AuthHandler(BaseHandler): if "user" not in authdict or "password" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) - user = authdict["user"] + user_id = authdict["user"] password = authdict["password"] - if not user.startswith('@'): - user = UserID.create(user, self.hs.hostname).to_string() + if not user_id.startswith('@'): + user_id = UserID.create(user_id, self.hs.hostname).to_string() - user_info = yield self.store.get_user_by_id(user_id=user) - if not user_info: - logger.warn("Attempted to login as %s but they do not exist", user) - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) - - stored_hash = user_info["password_hash"] - if bcrypt.checkpw(password, stored_hash): - defer.returnValue(user) - else: - logger.warn("Failed password login for user %s", user) - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + self._check_password(user_id, password) + defer.returnValue(user_id) @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip): @@ -270,6 +267,58 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] + @defer.inlineCallbacks + def login_with_password(self, user_id, password): + """ + Authenticates the user with their username and password. + + Used only by the v1 login API. + + Args: + user_id (str): User ID + password (str): Password + Returns: + The access token for the user's session. + Raises: + StoreError if there was a problem storing the token. + LoginError if there was an authentication problem. + """ + self._check_password(user_id, password) + + reg_handler = self.hs.get_handlers().registration_handler + access_token = reg_handler.generate_token(user_id) + logger.info("Adding token %s for user %s", access_token, user_id) + yield self.store.add_access_token_to_user(user_id, access_token) + defer.returnValue(access_token) + + def _check_password(self, user_id, password): + """Checks that user_id has passed password, raises LoginError if not.""" + user_info = yield self.store.get_user_by_id(user_id=user_id) + if not user_info: + logger.warn("Attempted to login as %s but they do not exist", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + + stored_hash = user_info["password_hash"] + if not bcrypt.checkpw(password, stored_hash): + logger.warn("Failed password login for user %s", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + + @defer.inlineCallbacks + def set_password(self, user_id, newpassword): + password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) + + yield self.store.user_set_password_hash(user_id, password_hash) + yield self.store.user_delete_access_tokens(user_id) + yield self.hs.get_pusherpool().remove_pushers_by_user(user_id) + yield self.store.flush_user(user_id) + + @defer.inlineCallbacks + def add_threepid(self, user_id, medium, address, validated_at): + yield self.store.user_add_threepid( + user_id, medium, address, validated_at, + self.hs.get_clock().time_msec() + ) + def _save_session(self, session): # TODO: Persistent storage logger.debug("Saving session %s", session) |