diff options
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r-- | synapse/storage/_base.py | 146 |
1 files changed, 134 insertions, 12 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f1a5366b95..3d895da43c 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -106,6 +106,14 @@ class LoggingTransaction(object): def __iter__(self): return self.txn.__iter__() + def execute_batch(self, sql, args): + if isinstance(self.database_engine, PostgresEngine): + from psycopg2.extras import execute_batch + self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) + else: + for val in args: + self.execute(sql, val) + def execute(self, sql, *args): self._do_execute(self.txn.execute, sql, *args) @@ -699,20 +707,34 @@ class SQLBaseStore(object): else: return "%s = ?" % (key,) - # First try to update. - sql = "UPDATE %s SET %s WHERE %s" % ( - table, - ", ".join("%s = ?" % (k,) for k in values), - " AND ".join(_getwhere(k) for k in keyvalues) - ) - sqlargs = list(values.values()) + list(keyvalues.values()) + if not values: + # If `values` is empty, then all of the values we care about are in + # the unique key, so there is nothing to UPDATE. We can just do a + # SELECT instead to see if it exists. + sql = "SELECT 1 FROM %s WHERE %s" % ( + table, + " AND ".join(_getwhere(k) for k in keyvalues) + ) + sqlargs = list(keyvalues.values()) + txn.execute(sql, sqlargs) + if txn.fetchall(): + # We have an existing record. + return False + else: + # First try to update. + sql = "UPDATE %s SET %s WHERE %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in values), + " AND ".join(_getwhere(k) for k in keyvalues) + ) + sqlargs = list(values.values()) + list(keyvalues.values()) - txn.execute(sql, sqlargs) - if txn.rowcount > 0: - # successfully updated at least one row. - return False + txn.execute(sql, sqlargs) + if txn.rowcount > 0: + # successfully updated at least one row. + return False - # We didn't update any rows so insert a new one + # We didn't find any existing rows, so insert a new one allvalues = {} allvalues.update(keyvalues) allvalues.update(values) @@ -759,6 +781,106 @@ class SQLBaseStore(object): ) txn.execute(sql, list(allvalues.values())) + def _simple_upsert_many_txn( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + if ( + self.database_engine.can_native_upsert + and table not in self._unsafe_to_upsert_tables + ): + return self._simple_upsert_many_txn_native_upsert( + txn, table, key_names, key_values, value_names, value_values + ) + else: + return self._simple_upsert_many_txn_emulated( + txn, table, key_names, key_values, value_names, value_values + ) + + def _simple_upsert_many_txn_emulated( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times, but without native UPSERT support or batching. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + # No value columns, therefore make a blank list so that the following + # zip() works correctly. + if not value_names: + value_values = [() for x in range(len(key_values))] + + for keyv, valv in zip(key_values, value_values): + _keys = {x: y for x, y in zip(key_names, keyv)} + _vals = {x: y for x, y in zip(value_names, valv)} + + self._simple_upsert_txn_emulated(txn, table, _keys, _vals) + + def _simple_upsert_many_txn_native_upsert( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times, using batching where possible. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + allnames = [] + allnames.extend(key_names) + allnames.extend(value_names) + + if not value_names: + # No value columns, therefore make a blank list so that the + # following zip() works correctly. + latter = "NOTHING" + value_values = [() for x in range(len(key_values))] + else: + latter = ( + "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in value_names) + ) + + sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( + table, + ", ".join(k for k in allnames), + ", ".join("?" for _ in allnames), + ", ".join(key_names), + latter, + ) + + args = [] + + for x, y in zip(key_values, value_values): + args.append(tuple(x) + tuple(y)) + + return txn.execute_batch(sql, args) + def _simple_select_one(self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"): """Executes a SELECT query on the named table, which is expected to |