diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 7a4afe446d..a740cc3da3 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -432,13 +432,18 @@ class AuthHandler(BaseHandler):
)
@defer.inlineCallbacks
- def set_password(self, user_id, newpassword):
+ def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword)
+ except_access_token_ids = [requester.access_token_id] if requester else []
+
yield self.store.user_set_password_hash(user_id, password_hash)
- yield self.store.user_delete_access_tokens(user_id)
- yield self.hs.get_pusherpool().remove_pushers_by_user(user_id)
- yield self.store.flush_user(user_id)
+ yield self.store.user_delete_access_tokens_except(
+ user_id, except_access_token_ids
+ )
+ yield self.hs.get_pusherpool().remove_pushers_by_user_except_access_tokens(
+ user_id, except_access_token_ids
+ )
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 772a095f8b..28ec94d866 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -92,14 +92,14 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
- def remove_pushers_by_user(self, user_id):
+ def remove_pushers_by_user_except_access_tokens(self, user_id, except_token_ids):
all = yield self.store.get_all_pushers()
logger.info(
- "Removing all pushers for user %s",
- user_id,
+ "Removing all pushers for user %s except access tokens ids %r",
+ user_id, except_token_ids
)
for p in all:
- if p['user_name'] == user_id:
+ if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 688b051580..dd4ea45588 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -79,7 +79,7 @@ class PasswordRestServlet(RestServlet):
new_password = params['new_password']
yield self.auth_handler.set_password(
- user_id, new_password
+ user_id, new_password, requester
)
defer.returnValue((200, {}))
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index aa49f53458..5eef7ebcc7 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -208,14 +208,26 @@ class RegistrationStore(SQLBaseStore):
)
@defer.inlineCallbacks
- def flush_user(self, user_id):
- rows = yield self._execute(
- 'flush_user', None,
- "SELECT token FROM access_tokens WHERE user_id = ?",
- user_id
- )
- for r in rows:
- self.get_user_by_access_token.invalidate((r,))
+ def user_delete_access_tokens_except(self, user_id, except_token_ids):
+ def f(txn):
+ txn.execute(
+ "SELECT id, token FROM access_tokens WHERE user_id = ? LIMIT 50",
+ (user_id,)
+ )
+ rows = txn.fetchall()
+ for r in rows:
+ if r[0] in except_token_ids:
+ continue
+
+ txn.call_after(self.get_user_by_access_token.invalidate, (r[1],))
+ txn.execute(
+ "DELETE FROM access_tokens WHERE id in (%s)" % ",".join(
+ ["?" for _ in rows]
+ ), [r[0] for r in rows]
+ )
+ return len(rows) == 50
+ while (yield self.runInteraction("user_delete_access_tokens_except", f)):
+ pass
@cached()
def get_user_by_access_token(self, token):
|