summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py146
1 files changed, 134 insertions, 12 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index f1a5366b95..3d895da43c 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -106,6 +106,14 @@ class LoggingTransaction(object):
     def __iter__(self):
         return self.txn.__iter__()
 
+    def execute_batch(self, sql, args):
+        if isinstance(self.database_engine, PostgresEngine):
+            from psycopg2.extras import execute_batch
+            self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
+        else:
+            for val in args:
+                self.execute(sql, val)
+
     def execute(self, sql, *args):
         self._do_execute(self.txn.execute, sql, *args)
 
@@ -699,20 +707,34 @@ class SQLBaseStore(object):
             else:
                 return "%s = ?" % (key,)
 
-        # First try to update.
-        sql = "UPDATE %s SET %s WHERE %s" % (
-            table,
-            ", ".join("%s = ?" % (k,) for k in values),
-            " AND ".join(_getwhere(k) for k in keyvalues)
-        )
-        sqlargs = list(values.values()) + list(keyvalues.values())
+        if not values:
+            # If `values` is empty, then all of the values we care about are in
+            # the unique key, so there is nothing to UPDATE. We can just do a
+            # SELECT instead to see if it exists.
+            sql = "SELECT 1 FROM %s WHERE %s" % (
+                table,
+                " AND ".join(_getwhere(k) for k in keyvalues)
+            )
+            sqlargs = list(keyvalues.values())
+            txn.execute(sql, sqlargs)
+            if txn.fetchall():
+                # We have an existing record.
+                return False
+        else:
+            # First try to update.
+            sql = "UPDATE %s SET %s WHERE %s" % (
+                table,
+                ", ".join("%s = ?" % (k,) for k in values),
+                " AND ".join(_getwhere(k) for k in keyvalues)
+            )
+            sqlargs = list(values.values()) + list(keyvalues.values())
 
-        txn.execute(sql, sqlargs)
-        if txn.rowcount > 0:
-            # successfully updated at least one row.
-            return False
+            txn.execute(sql, sqlargs)
+            if txn.rowcount > 0:
+                # successfully updated at least one row.
+                return False
 
-        # We didn't update any rows so insert a new one
+        # We didn't find any existing rows, so insert a new one
         allvalues = {}
         allvalues.update(keyvalues)
         allvalues.update(values)
@@ -759,6 +781,106 @@ class SQLBaseStore(object):
         )
         txn.execute(sql, list(allvalues.values()))
 
+    def _simple_upsert_many_txn(
+        self, txn, table, key_names, key_values, value_names, value_values
+    ):
+        """
+        Upsert, many times.
+
+        Args:
+            table (str): The table to upsert into
+            key_names (list[str]): The key column names.
+            key_values (list[list]): A list of each row's key column values.
+            value_names (list[str]): The value column names. If empty, no
+                values will be used, even if value_values is provided.
+            value_values (list[list]): A list of each row's value column values.
+        Returns:
+            None
+        """
+        if (
+            self.database_engine.can_native_upsert
+            and table not in self._unsafe_to_upsert_tables
+        ):
+            return self._simple_upsert_many_txn_native_upsert(
+                txn, table, key_names, key_values, value_names, value_values
+            )
+        else:
+            return self._simple_upsert_many_txn_emulated(
+                txn, table, key_names, key_values, value_names, value_values
+            )
+
+    def _simple_upsert_many_txn_emulated(
+        self, txn, table, key_names, key_values, value_names, value_values
+    ):
+        """
+        Upsert, many times, but without native UPSERT support or batching.
+
+        Args:
+            table (str): The table to upsert into
+            key_names (list[str]): The key column names.
+            key_values (list[list]): A list of each row's key column values.
+            value_names (list[str]): The value column names. If empty, no
+                values will be used, even if value_values is provided.
+            value_values (list[list]): A list of each row's value column values.
+        Returns:
+            None
+        """
+        # No value columns, therefore make a blank list so that the following
+        # zip() works correctly.
+        if not value_names:
+            value_values = [() for x in range(len(key_values))]
+
+        for keyv, valv in zip(key_values, value_values):
+            _keys = {x: y for x, y in zip(key_names, keyv)}
+            _vals = {x: y for x, y in zip(value_names, valv)}
+
+            self._simple_upsert_txn_emulated(txn, table, _keys, _vals)
+
+    def _simple_upsert_many_txn_native_upsert(
+        self, txn, table, key_names, key_values, value_names, value_values
+    ):
+        """
+        Upsert, many times, using batching where possible.
+
+        Args:
+            table (str): The table to upsert into
+            key_names (list[str]): The key column names.
+            key_values (list[list]): A list of each row's key column values.
+            value_names (list[str]): The value column names. If empty, no
+                values will be used, even if value_values is provided.
+            value_values (list[list]): A list of each row's value column values.
+        Returns:
+            None
+        """
+        allnames = []
+        allnames.extend(key_names)
+        allnames.extend(value_names)
+
+        if not value_names:
+            # No value columns, therefore make a blank list so that the
+            # following zip() works correctly.
+            latter = "NOTHING"
+            value_values = [() for x in range(len(key_values))]
+        else:
+            latter = (
+                "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in value_names)
+            )
+
+        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
+            table,
+            ", ".join(k for k in allnames),
+            ", ".join("?" for _ in allnames),
+            ", ".join(key_names),
+            latter,
+        )
+
+        args = []
+
+        for x, y in zip(key_values, value_values):
+            args.append(tuple(x) + tuple(y))
+
+        return txn.execute_batch(sql, args)
+
     def _simple_select_one(self, table, keyvalues, retcols,
                            allow_none=False, desc="_simple_select_one"):
         """Executes a SELECT query on the named table, which is expected to