summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/auth.py6
-rw-r--r--synapse/push/pusherpool.py8
-rw-r--r--synapse/replication/slave/storage/_base.py2
-rw-r--r--synapse/replication/slave/storage/registration.py2
-rw-r--r--synapse/storage/_base.py1
-rw-r--r--synapse/storage/registration.py70
6 files changed, 42 insertions, 47 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a582d6334b..6986930c0d 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -741,7 +741,7 @@ class AuthHandler(BaseHandler):
     def set_password(self, user_id, newpassword, requester=None):
         password_hash = self.hash(newpassword)
 
-        except_access_token_ids = [requester.access_token_id] if requester else []
+        except_access_token_id = requester.access_token_id if requester else None
 
         try:
             yield self.store.user_set_password_hash(user_id, password_hash)
@@ -750,10 +750,10 @@ class AuthHandler(BaseHandler):
                 raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
             raise e
         yield self.store.user_delete_access_tokens(
-            user_id, except_access_token_ids
+            user_id, except_access_token_id
         )
         yield self.hs.get_pusherpool().remove_pushers_by_user(
-            user_id, except_access_token_ids
+            user_id, except_access_token_id
         )
 
     @defer.inlineCallbacks
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 5853ec36a9..54c0f1b849 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -102,14 +102,14 @@ class PusherPool:
                 yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
 
     @defer.inlineCallbacks
-    def remove_pushers_by_user(self, user_id, except_token_ids=[]):
+    def remove_pushers_by_user(self, user_id, except_access_token_id=None):
         all = yield self.store.get_all_pushers()
         logger.info(
-            "Removing all pushers for user %s except access tokens ids %r",
-            user_id, except_token_ids
+            "Removing all pushers for user %s except access tokens id %r",
+            user_id, except_access_token_id
         )
         for p in all:
-            if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
+            if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
                     p['app_id'], p['pushkey'], p['user_name']
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index d839d169ab..f19540d6bb 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -51,6 +51,6 @@ class BaseSlavedStore(SQLBaseStore):
                 try:
                     getattr(self, cache_func).invalidate(tuple(keys))
                 except AttributeError:
-                    logger.warn("Got unexpected cache_func: %r", cache_func)
+                    logger.info("Got unexpected cache_func: %r", cache_func)
             self._cache_id_gen.advance(int(stream["position"]))
         return defer.succeed(None)
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
index 307833f9e1..38b78b97fc 100644
--- a/synapse/replication/slave/storage/registration.py
+++ b/synapse/replication/slave/storage/registration.py
@@ -25,6 +25,6 @@ class SlavedRegistrationStore(BaseSlavedStore):
     # TODO: use the cached version and invalidate deleted tokens
     get_user_by_access_token = RegistrationStore.__dict__[
         "get_user_by_access_token"
-    ].orig
+    ]
 
     _query_for_auth = DataStore._query_for_auth.__func__
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b0923a9cad..0a2e78fd81 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -880,6 +880,7 @@ class SQLBaseStore(object):
             ctx = self._cache_id_gen.get_next()
             stream_id = ctx.__enter__()
             txn.call_after(ctx.__exit__, None, None, None)
+            txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
             self._simple_insert_txn(
                 txn,
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 7e7d32eb66..19cb3b31c6 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -251,7 +251,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         self.get_user_by_id.invalidate((user_id,))
 
     @defer.inlineCallbacks
-    def user_delete_access_tokens(self, user_id, except_token_ids=[],
+    def user_delete_access_tokens(self, user_id, except_token_id=None,
                                   device_id=None,
                                   delete_refresh_tokens=False):
         """
@@ -259,7 +259,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
 
         Args:
             user_id (str):  ID of user the tokens belong to
-            except_token_ids (list[str]): list of access_tokens which should
+            except_token_id (str): list of access_tokens IDs which should
                 *not* be deleted
             device_id (str|None):  ID of device the tokens are associated with.
                 If None, tokens associated with any device (or no device) will
@@ -269,53 +269,45 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         Returns:
             defer.Deferred:
         """
-        def f(txn, table, except_tokens, call_after_delete):
-            sql = "SELECT token FROM %s WHERE user_id = ?" % table
-            clauses = [user_id]
-
+        def f(txn):
+            keyvalues = {
+                "user_id": user_id,
+            }
             if device_id is not None:
-                sql += " AND device_id = ?"
-                clauses.append(device_id)
+                keyvalues["device_id"] = device_id
 
-            if except_tokens:
-                sql += " AND id NOT IN (%s)" % (
-                    ",".join(["?" for _ in except_tokens]),
+            if delete_refresh_tokens:
+                self._simple_delete_txn(
+                    txn,
+                    table="refresh_tokens",
+                    keyvalues=keyvalues,
                 )
-                clauses += except_tokens
-
-            txn.execute(sql, clauses)
 
-            rows = txn.fetchall()
+            items = keyvalues.items()
+            where_clause = " AND ".join(k + " = ?" for k, _ in items)
+            values = [v for _, v in items]
+            if except_token_id:
+                where_clause += " AND id != ?"
+                values.append(except_token_id)
 
-            n = 100
-            chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
-            for chunk in chunks:
-                if call_after_delete:
-                    for row in chunk:
-                        txn.call_after(call_after_delete, (row[0],))
+            txn.execute(
+                "SELECT token FROM access_tokens WHERE %s" % where_clause,
+                values
+            )
+            rows = self.cursor_to_dict(txn)
 
-                txn.execute(
-                    "DELETE FROM %s WHERE token in (%s)" % (
-                        table,
-                        ",".join(["?" for _ in chunk]),
-                    ), [r[0] for r in chunk]
+            for row in rows:
+                self._invalidate_cache_and_stream(
+                    txn, self.get_user_by_access_token, (row["token"],)
                 )
 
-        # delete refresh tokens first, to stop new access tokens being
-        # allocated while our backs are turned
-        if delete_refresh_tokens:
-            yield self.runInteraction(
-                "user_delete_access_tokens", f,
-                table="refresh_tokens",
-                except_tokens=[],
-                call_after_delete=None,
+            txn.execute(
+                "DELETE FROM access_tokens WHERE %s" % where_clause,
+                values
             )
 
         yield self.runInteraction(
             "user_delete_access_tokens", f,
-            table="access_tokens",
-            except_tokens=except_token_ids,
-            call_after_delete=self.get_user_by_access_token.invalidate,
         )
 
     def delete_access_token(self, access_token):
@@ -328,7 +320,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 },
             )
 
-            txn.call_after(self.get_user_by_access_token.invalidate, (access_token,))
+            self._invalidate_cache_and_stream(
+                txn, self.get_user_by_access_token, (access_token,)
+            )
 
         return self.runInteraction("delete_access_token", f)