diff --git a/changelog.d/13822.misc b/changelog.d/13822.misc
new file mode 100644
index 0000000000..dbc77cbcfa
--- /dev/null
+++ b/changelog.d/13822.misc
@@ -0,0 +1 @@
+Support providing an index predicate clause when doing upserts.
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index cf1eabc437..bf5e7ee7be 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -533,6 +533,7 @@ class BackgroundUpdater:
index_name: name of index to add
table: table to add index to
columns: columns/expressions to include in index
+ where_clause: A WHERE clause to specify a partial unique index.
unique: true to make a UNIQUE index
psql_only: true to only create this index on psql databases (useful
for virtual sqlite tables)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e881bff7fb..921cd4dc5e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -1191,6 +1191,7 @@ class DatabasePool:
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
+ where_clause: Optional[str] = None,
lock: bool = True,
) -> bool:
"""
@@ -1203,6 +1204,7 @@ class DatabasePool:
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
+ where_clause: An index predicate to apply to the upsert.
lock: True to lock the table when doing the upsert. Unused when performing
a native upsert.
Returns:
@@ -1213,7 +1215,12 @@ class DatabasePool:
if table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_txn_native_upsert(
- txn, table, keyvalues, values, insertion_values=insertion_values
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values=insertion_values,
+ where_clause=where_clause,
)
else:
return self.simple_upsert_txn_emulated(
@@ -1222,6 +1229,7 @@ class DatabasePool:
keyvalues,
values,
insertion_values=insertion_values,
+ where_clause=where_clause,
lock=lock,
)
@@ -1232,6 +1240,7 @@ class DatabasePool:
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
+ where_clause: Optional[str] = None,
lock: bool = True,
) -> bool:
"""
@@ -1240,6 +1249,7 @@ class DatabasePool:
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
+ where_clause: An index predicate to apply to the upsert.
lock: True to lock the table when doing the upsert.
Returns:
Returns True if a row was inserted or updated (i.e. if `values` is
@@ -1259,14 +1269,17 @@ class DatabasePool:
else:
return "%s = ?" % (key,)
+ # Generate a where clause of each keyvalue and optionally the provided
+ # index predicate.
+ where = [_getwhere(k) for k in keyvalues]
+ if where_clause:
+ where.append(where_clause)
+
if not values:
# If `values` is empty, then all of the values we care about are in
# the unique key, so there is nothing to UPDATE. We can just do a
# SELECT instead to see if it exists.
- sql = "SELECT 1 FROM %s WHERE %s" % (
- table,
- " AND ".join(_getwhere(k) for k in keyvalues),
- )
+ sql = "SELECT 1 FROM %s WHERE %s" % (table, " AND ".join(where))
sqlargs = list(keyvalues.values())
txn.execute(sql, sqlargs)
if txn.fetchall():
@@ -1277,7 +1290,7 @@ class DatabasePool:
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
- " AND ".join(_getwhere(k) for k in keyvalues),
+ " AND ".join(where),
)
sqlargs = list(values.values()) + list(keyvalues.values())
@@ -1307,6 +1320,7 @@ class DatabasePool:
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
+ where_clause: Optional[str] = None,
) -> bool:
"""
Use the native UPSERT functionality in PostgreSQL.
@@ -1316,6 +1330,7 @@ class DatabasePool:
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
+ where_clause: An index predicate to apply to the upsert.
Returns:
Returns True if a row was inserted or updated (i.e. if `values` is
@@ -1331,11 +1346,12 @@ class DatabasePool:
allvalues.update(values)
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
- sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
+ sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues),
", ".join(k for k in keyvalues),
+ f"WHERE {where_clause}" if where_clause else "",
latter,
)
txn.execute(sql, list(allvalues.values()))
|