summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/auth.py13
-rw-r--r--synapse/push/pusherpool.py8
-rw-r--r--synapse/rest/client/v2_alpha/account.py2
-rw-r--r--synapse/storage/registration.py28
4 files changed, 34 insertions, 17 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 7a4afe446d..a740cc3da3 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -432,13 +432,18 @@ class AuthHandler(BaseHandler):
         )
 
     @defer.inlineCallbacks
-    def set_password(self, user_id, newpassword):
+    def set_password(self, user_id, newpassword, requester=None):
         password_hash = self.hash(newpassword)
 
+        except_access_token_ids = [requester.access_token_id] if requester else []
+
         yield self.store.user_set_password_hash(user_id, password_hash)
-        yield self.store.user_delete_access_tokens(user_id)
-        yield self.hs.get_pusherpool().remove_pushers_by_user(user_id)
-        yield self.store.flush_user(user_id)
+        yield self.store.user_delete_access_tokens_except(
+                user_id, except_access_token_ids
+        )
+        yield self.hs.get_pusherpool().remove_pushers_by_user_except_access_tokens(
+                user_id, except_access_token_ids
+        )
 
     @defer.inlineCallbacks
     def add_threepid(self, user_id, medium, address, validated_at):
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 772a095f8b..28ec94d866 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -92,14 +92,14 @@ class PusherPool:
                 yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
 
     @defer.inlineCallbacks
-    def remove_pushers_by_user(self, user_id):
+    def remove_pushers_by_user_except_access_tokens(self, user_id, except_token_ids):
         all = yield self.store.get_all_pushers()
         logger.info(
-            "Removing all pushers for user %s",
-            user_id,
+            "Removing all pushers for user %s except access tokens ids %r",
+            user_id, except_token_ids
         )
         for p in all:
-            if p['user_name'] == user_id:
+            if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
                     p['app_id'], p['pushkey'], p['user_name']
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 688b051580..dd4ea45588 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -79,7 +79,7 @@ class PasswordRestServlet(RestServlet):
         new_password = params['new_password']
 
         yield self.auth_handler.set_password(
-            user_id, new_password
+            user_id, new_password, requester
         )
 
         defer.returnValue((200, {}))
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index aa49f53458..5eef7ebcc7 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -208,14 +208,26 @@ class RegistrationStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def flush_user(self, user_id):
-        rows = yield self._execute(
-            'flush_user', None,
-            "SELECT token FROM access_tokens WHERE user_id = ?",
-            user_id
-        )
-        for r in rows:
-            self.get_user_by_access_token.invalidate((r,))
+    def user_delete_access_tokens_except(self, user_id, except_token_ids):
+        def f(txn):
+            txn.execute(
+                "SELECT id, token FROM access_tokens WHERE user_id = ? LIMIT 50",
+                    (user_id,)
+            )
+            rows = txn.fetchall()
+            for r in rows:
+                if r[0] in except_token_ids:
+                    continue
+
+                txn.call_after(self.get_user_by_access_token.invalidate, (r[1],))
+            txn.execute(
+                "DELETE FROM access_tokens WHERE id in (%s)" % ",".join(
+                    ["?" for _ in rows]
+                ), [r[0] for r in rows]
+            )
+            return len(rows) == 50
+        while (yield self.runInteraction("user_delete_access_tokens_except", f)):
+            pass
 
     @cached()
     def get_user_by_access_token(self, token):