summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13822.misc1
-rw-r--r--synapse/storage/background_updates.py1
-rw-r--r--synapse/storage/database.py30
3 files changed, 25 insertions, 7 deletions
diff --git a/changelog.d/13822.misc b/changelog.d/13822.misc
new file mode 100644
index 0000000000..dbc77cbcfa
--- /dev/null
+++ b/changelog.d/13822.misc
@@ -0,0 +1 @@
+Support providing an index predicate clause when doing upserts.
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index cf1eabc437..bf5e7ee7be 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -533,6 +533,7 @@ class BackgroundUpdater:
             index_name: name of index to add
             table: table to add index to
             columns: columns/expressions to include in index
+            where_clause: A WHERE clause to specify a partial unique index.
             unique: true to make a UNIQUE index
             psql_only: true to only create this index on psql databases (useful
                 for virtual sqlite tables)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e881bff7fb..921cd4dc5e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -1191,6 +1191,7 @@ class DatabasePool:
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
         insertion_values: Optional[Dict[str, Any]] = None,
+        where_clause: Optional[str] = None,
         lock: bool = True,
     ) -> bool:
         """
@@ -1203,6 +1204,7 @@ class DatabasePool:
             keyvalues: The unique key tables and their new values
             values: The nonunique columns and their new values
             insertion_values: additional key/values to use only when inserting
+            where_clause: An index predicate to apply to the upsert.
             lock: True to lock the table when doing the upsert. Unused when performing
                 a native upsert.
         Returns:
@@ -1213,7 +1215,12 @@ class DatabasePool:
 
         if table not in self._unsafe_to_upsert_tables:
             return self.simple_upsert_txn_native_upsert(
-                txn, table, keyvalues, values, insertion_values=insertion_values
+                txn,
+                table,
+                keyvalues,
+                values,
+                insertion_values=insertion_values,
+                where_clause=where_clause,
             )
         else:
             return self.simple_upsert_txn_emulated(
@@ -1222,6 +1229,7 @@ class DatabasePool:
                 keyvalues,
                 values,
                 insertion_values=insertion_values,
+                where_clause=where_clause,
                 lock=lock,
             )
 
@@ -1232,6 +1240,7 @@ class DatabasePool:
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
         insertion_values: Optional[Dict[str, Any]] = None,
+        where_clause: Optional[str] = None,
         lock: bool = True,
     ) -> bool:
         """
@@ -1240,6 +1249,7 @@ class DatabasePool:
             keyvalues: The unique key tables and their new values
             values: The nonunique columns and their new values
             insertion_values: additional key/values to use only when inserting
+            where_clause: An index predicate to apply to the upsert.
             lock: True to lock the table when doing the upsert.
         Returns:
             Returns True if a row was inserted or updated (i.e. if `values` is
@@ -1259,14 +1269,17 @@ class DatabasePool:
             else:
                 return "%s = ?" % (key,)
 
+        # Generate a where clause of each keyvalue and optionally the provided
+        # index predicate.
+        where = [_getwhere(k) for k in keyvalues]
+        if where_clause:
+            where.append(where_clause)
+
         if not values:
             # If `values` is empty, then all of the values we care about are in
             # the unique key, so there is nothing to UPDATE. We can just do a
             # SELECT instead to see if it exists.
-            sql = "SELECT 1 FROM %s WHERE %s" % (
-                table,
-                " AND ".join(_getwhere(k) for k in keyvalues),
-            )
+            sql = "SELECT 1 FROM %s WHERE %s" % (table, " AND ".join(where))
             sqlargs = list(keyvalues.values())
             txn.execute(sql, sqlargs)
             if txn.fetchall():
@@ -1277,7 +1290,7 @@ class DatabasePool:
             sql = "UPDATE %s SET %s WHERE %s" % (
                 table,
                 ", ".join("%s = ?" % (k,) for k in values),
-                " AND ".join(_getwhere(k) for k in keyvalues),
+                " AND ".join(where),
             )
             sqlargs = list(values.values()) + list(keyvalues.values())
 
@@ -1307,6 +1320,7 @@ class DatabasePool:
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
         insertion_values: Optional[Dict[str, Any]] = None,
+        where_clause: Optional[str] = None,
     ) -> bool:
         """
         Use the native UPSERT functionality in PostgreSQL.
@@ -1316,6 +1330,7 @@ class DatabasePool:
             keyvalues: The unique key tables and their new values
             values: The nonunique columns and their new values
             insertion_values: additional key/values to use only when inserting
+            where_clause: An index predicate to apply to the upsert.
 
         Returns:
             Returns True if a row was inserted or updated (i.e. if `values` is
@@ -1331,11 +1346,12 @@ class DatabasePool:
             allvalues.update(values)
             latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
 
-        sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
+        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % (
             table,
             ", ".join(k for k in allvalues),
             ", ".join("?" for _ in allvalues),
             ", ".join(k for k in keyvalues),
+            f"WHERE {where_clause}" if where_clause else "",
             latter,
         )
         txn.execute(sql, list(allvalues.values()))