summary refs log tree commit diff
diff options
context:
space:
mode:
authorAndrew Morgan <andrewm@element.io>2022-07-26 12:07:40 +0100
committerAndrew Morgan <andrewm@element.io>2022-08-03 14:39:43 +0100
commit7496a37e03333c523d73a20ad7ed3d07e5a9e0ce (patch)
treecee730caaef50a3134ed62a96049f2571c2142e9
parentRemove unused 'DataStore.get_users' method (diff)
downloadsynapse-7496a37e03333c523d73a20ad7ed3d07e5a9e0ce.tar.xz
Allow deleting all rows when passing empty keyvalues dict to simple_delete{_many}
-rw-r--r--synapse/storage/database.py32
1 files changed, 19 insertions, 13 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b394a6658b..6e4cdfdd89 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -2092,7 +2092,9 @@ class DatabasePool:
 
         Args:
             table: string giving the table name
-            keyvalues: dict of column names and values to select the row with
+            keyvalues: dict of column names and values to select the row with. If an
+                empty dict is passed then no selection clauses are applied. Therefore,
+                ALL rows will be deleted.
             desc: description of the transaction, for logging and metrics
 
         Returns:
@@ -2112,15 +2114,17 @@ class DatabasePool:
 
         Args:
             table: string giving the table name
-            keyvalues: dict of column names and values to select the row with
+            keyvalues: dict of column names and values to select the row with. If an
+                empty dict is passed then no selection clauses are applied. Therefore,
+                ALL rows will be deleted.
 
         Returns:
             The number of deleted rows.
         """
-        sql = "DELETE FROM %s WHERE %s" % (
-            table,
-            " AND ".join("%s = ?" % (k,) for k in keyvalues),
-        )
+        sql = "DELETE FROM %s" % (table,)
+
+        if keyvalues:
+            sql += " WHERE %s" % (" AND ".join("%s = ?" % (k,) for k in keyvalues),)
 
         txn.execute(sql, list(keyvalues.values()))
         return txn.rowcount
@@ -2135,18 +2139,19 @@ class DatabasePool:
     ) -> int:
         """Executes a DELETE query on the named table.
 
-        Filters rows by if value of `column` is in `iterable`.
+        Filters rows if value of `column` is in `iterable`.
 
         Args:
             table: string giving the table name
             column: column name to test for inclusion against `iterable`
             iterable: list of values to match against `column`. NB cannot be a generator
                 as it may be evaluated multiple times.
-            keyvalues: dict of column names and values to select the rows with
+            keyvalues: dict of column names and values to select the rows with. If an
+                emtpy dict is passed, this option will have no effect.
             desc: description of the transaction, for logging and metrics
 
         Returns:
-            Number rows deleted
+            The number of deleted rows.
         """
         return await self.runInteraction(
             desc,
@@ -2178,10 +2183,11 @@ class DatabasePool:
             column: column name to test for inclusion against `values`
             values: values of `column` which choose rows to delete
             keyvalues: dict of extra column names and values to select the rows
-                with. They will be ANDed together with the main predicate.
+                with. They will be ANDed together with the main predicate. If an
+                empty dict is passed, this option will have no effect.
 
         Returns:
-            Number rows deleted
+            The number of deleted rows.
         """
         if not values:
             return 0
@@ -2195,8 +2201,8 @@ class DatabasePool:
             clauses.append("%s = ?" % (key,))
             values.append(value)
 
-        if clauses:
-            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
+        sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
+
         txn.execute(sql, values)
 
         return txn.rowcount