summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py85
1 files changed, 70 insertions, 15 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 62e82a2570..82d458b424 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
 
 
 class AuthHandler(BaseHandler):
+    SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
     def __init__(self, hs):
         super(AuthHandler, self).__init__(hs)
@@ -66,15 +67,18 @@ class AuthHandler(BaseHandler):
                         'auth' key: this method prompts for auth if none is sent.
             clientip (str): The IP address of the client.
         Returns:
-            A tuple of (authed, dict, dict) where authed is true if the client
-            has successfully completed an auth flow. If it is true, the first
-            dict contains the authenticated credentials of each stage.
+            A tuple of (authed, dict, dict, session_id) where authed is true if
+            the client has successfully completed an auth flow. If it is true
+            the first dict contains the authenticated credentials of each stage.
 
             If authed is false, the first dictionary is the server response to
             the login request and should be passed back to the client.
 
             In either case, the second dict contains the parameters for this
             request (which may have been given only in a previous call).
+
+            session_id is the ID of this session, either passed in by the client
+            or assigned by the call to check_auth
         """
 
         authdict = None
@@ -103,7 +107,10 @@ class AuthHandler(BaseHandler):
 
         if not authdict:
             defer.returnValue(
-                (False, self._auth_dict_for_flows(flows, session), clientdict)
+                (
+                    False, self._auth_dict_for_flows(flows, session),
+                    clientdict, session['id']
+                )
             )
 
         if 'creds' not in session:
@@ -122,12 +129,11 @@ class AuthHandler(BaseHandler):
         for f in flows:
             if len(set(f) - set(creds.keys())) == 0:
                 logger.info("Auth completed with creds: %r", creds)
-                self._remove_session(session)
-                defer.returnValue((True, creds, clientdict))
+                defer.returnValue((True, creds, clientdict, session['id']))
 
         ret = self._auth_dict_for_flows(flows, session)
         ret['completed'] = creds.keys()
-        defer.returnValue((False, ret, clientdict))
+        defer.returnValue((False, ret, clientdict, session['id']))
 
     @defer.inlineCallbacks
     def add_oob_auth(self, stagetype, authdict, clientip):
@@ -154,6 +160,43 @@ class AuthHandler(BaseHandler):
             defer.returnValue(True)
         defer.returnValue(False)
 
+    def get_session_id(self, clientdict):
+        """
+        Gets the session ID for a client given the client dictionary
+        :param clientdict: The dictionary sent by the client in the request
+        :return: The string session ID the client sent. If the client did not
+                 send a session ID, returns None.
+        """
+        sid = None
+        if clientdict and 'auth' in clientdict:
+            authdict = clientdict['auth']
+            if 'session' in authdict:
+                sid = authdict['session']
+        return sid
+
+    def set_session_data(self, session_id, key, value):
+        """
+        Store a key-value pair into the sessions data associated with this
+        request. This data is stored server-side and cannot be modified by
+        the client.
+        :param session_id: (string) The ID of this session as returned from check_auth
+        :param key: (string) The key to store the data under
+        :param value: (any) The data to store
+        """
+        sess = self._get_session_info(session_id)
+        sess.setdefault('serverdict', {})[key] = value
+        self._save_session(sess)
+
+    def get_session_data(self, session_id, key, default=None):
+        """
+        Retrieve data stored with set_session_data
+        :param session_id: (string) The ID of this session as returned from check_auth
+        :param key: (string) The key to store the data under
+        :param default: (any) Value to return if the key has not been set
+        """
+        sess = self._get_session_info(session_id)
+        return sess.setdefault('serverdict', {}).get(key, default)
+
     @defer.inlineCallbacks
     def _check_password_auth(self, authdict, _):
         if "user" not in authdict or "password" not in authdict:
@@ -432,13 +475,18 @@ class AuthHandler(BaseHandler):
         )
 
     @defer.inlineCallbacks
-    def set_password(self, user_id, newpassword):
+    def set_password(self, user_id, newpassword, requester=None):
         password_hash = self.hash(newpassword)
 
+        except_access_token_ids = [requester.access_token_id] if requester else []
+
         yield self.store.user_set_password_hash(user_id, password_hash)
-        yield self.store.user_delete_access_tokens(user_id)
-        yield self.hs.get_pusherpool().remove_pushers_by_user(user_id)
-        yield self.store.flush_user(user_id)
+        yield self.store.user_delete_access_tokens(
+            user_id, except_access_token_ids
+        )
+        yield self.hs.get_pusherpool().remove_pushers_by_user(
+            user_id, except_access_token_ids
+        )
 
     @defer.inlineCallbacks
     def add_threepid(self, user_id, medium, address, validated_at):
@@ -450,11 +498,18 @@ class AuthHandler(BaseHandler):
     def _save_session(self, session):
         # TODO: Persistent storage
         logger.debug("Saving session %s", session)
+        session["last_used"] = self.hs.get_clock().time_msec()
         self.sessions[session["id"]] = session
+        self._prune_sessions()
 
-    def _remove_session(self, session):
-        logger.debug("Removing session %s", session)
-        del self.sessions[session["id"]]
+    def _prune_sessions(self):
+        for sid, sess in self.sessions.items():
+            last_used = 0
+            if 'last_used' in sess:
+                last_used = sess['last_used']
+            now = self.hs.get_clock().time_msec()
+            if last_used < now - AuthHandler.SESSION_EXPIRE_MS:
+                del self.sessions[sid]
 
     def hash(self, password):
         """Computes a secure hash of password.
@@ -477,4 +532,4 @@ class AuthHandler(BaseHandler):
         Returns:
             Whether self.hash(password) == stored_hash (bool).
         """
-        return bcrypt.checkpw(password, stored_hash)
+        return bcrypt.hashpw(password, stored_hash) == stored_hash