summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/registration.py70
1 files changed, 32 insertions, 38 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 7e7d32eb66..19cb3b31c6 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -251,7 +251,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         self.get_user_by_id.invalidate((user_id,))
 
     @defer.inlineCallbacks
-    def user_delete_access_tokens(self, user_id, except_token_ids=[],
+    def user_delete_access_tokens(self, user_id, except_token_id=None,
                                   device_id=None,
                                   delete_refresh_tokens=False):
         """
@@ -259,7 +259,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
 
         Args:
             user_id (str):  ID of user the tokens belong to
-            except_token_ids (list[str]): list of access_tokens which should
+            except_token_id (str): list of access_tokens IDs which should
                 *not* be deleted
             device_id (str|None):  ID of device the tokens are associated with.
                 If None, tokens associated with any device (or no device) will
@@ -269,53 +269,45 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         Returns:
             defer.Deferred:
         """
-        def f(txn, table, except_tokens, call_after_delete):
-            sql = "SELECT token FROM %s WHERE user_id = ?" % table
-            clauses = [user_id]
-
+        def f(txn):
+            keyvalues = {
+                "user_id": user_id,
+            }
             if device_id is not None:
-                sql += " AND device_id = ?"
-                clauses.append(device_id)
+                keyvalues["device_id"] = device_id
 
-            if except_tokens:
-                sql += " AND id NOT IN (%s)" % (
-                    ",".join(["?" for _ in except_tokens]),
+            if delete_refresh_tokens:
+                self._simple_delete_txn(
+                    txn,
+                    table="refresh_tokens",
+                    keyvalues=keyvalues,
                 )
-                clauses += except_tokens
-
-            txn.execute(sql, clauses)
 
-            rows = txn.fetchall()
+            items = keyvalues.items()
+            where_clause = " AND ".join(k + " = ?" for k, _ in items)
+            values = [v for _, v in items]
+            if except_token_id:
+                where_clause += " AND id != ?"
+                values.append(except_token_id)
 
-            n = 100
-            chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
-            for chunk in chunks:
-                if call_after_delete:
-                    for row in chunk:
-                        txn.call_after(call_after_delete, (row[0],))
+            txn.execute(
+                "SELECT token FROM access_tokens WHERE %s" % where_clause,
+                values
+            )
+            rows = self.cursor_to_dict(txn)
 
-                txn.execute(
-                    "DELETE FROM %s WHERE token in (%s)" % (
-                        table,
-                        ",".join(["?" for _ in chunk]),
-                    ), [r[0] for r in chunk]
+            for row in rows:
+                self._invalidate_cache_and_stream(
+                    txn, self.get_user_by_access_token, (row["token"],)
                 )
 
-        # delete refresh tokens first, to stop new access tokens being
-        # allocated while our backs are turned
-        if delete_refresh_tokens:
-            yield self.runInteraction(
-                "user_delete_access_tokens", f,
-                table="refresh_tokens",
-                except_tokens=[],
-                call_after_delete=None,
+            txn.execute(
+                "DELETE FROM access_tokens WHERE %s" % where_clause,
+                values
             )
 
         yield self.runInteraction(
             "user_delete_access_tokens", f,
-            table="access_tokens",
-            except_tokens=except_token_ids,
-            call_after_delete=self.get_user_by_access_token.invalidate,
         )
 
     def delete_access_token(self, access_token):
@@ -328,7 +320,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 },
             )
 
-            txn.call_after(self.get_user_by_access_token.invalidate, (access_token,))
+            self._invalidate_cache_and_stream(
+                txn, self.get_user_by_access_token, (access_token,)
+            )
 
         return self.runInteraction("delete_access_token", f)