summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--.travis.yml2
-rw-r--r--changelog.d/4644.misc1
-rw-r--r--synapse/storage/_base.py146
-rw-r--r--tests/storage/test__base.py88
4 files changed, 224 insertions, 13 deletions
diff --git a/.travis.yml b/.travis.yml
index d88f10324f..5d763123a0 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -89,7 +89,7 @@ install:
   - psql -At -U postgres -c 'select version();' || true
 
   - pip install tox
-  
+
   # if we don't have python3.6 in this environment, travis unhelpfully gives us
   # a `python3.6` on our path which does nothing but spit out a warning. Tox
   # tries to run it (even if we're not running a py36 env), so the build logs
diff --git a/changelog.d/4644.misc b/changelog.d/4644.misc
new file mode 100644
index 0000000000..84137c3412
--- /dev/null
+++ b/changelog.d/4644.misc
@@ -0,0 +1 @@
+Introduce upsert batching functionality in the database layer.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index f1a5366b95..3d895da43c 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -106,6 +106,14 @@ class LoggingTransaction(object):
     def __iter__(self):
         return self.txn.__iter__()
 
+    def execute_batch(self, sql, args):
+        if isinstance(self.database_engine, PostgresEngine):
+            from psycopg2.extras import execute_batch
+            self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
+        else:
+            for val in args:
+                self.execute(sql, val)
+
     def execute(self, sql, *args):
         self._do_execute(self.txn.execute, sql, *args)
 
@@ -699,20 +707,34 @@ class SQLBaseStore(object):
             else:
                 return "%s = ?" % (key,)
 
-        # First try to update.
-        sql = "UPDATE %s SET %s WHERE %s" % (
-            table,
-            ", ".join("%s = ?" % (k,) for k in values),
-            " AND ".join(_getwhere(k) for k in keyvalues)
-        )
-        sqlargs = list(values.values()) + list(keyvalues.values())
+        if not values:
+            # If `values` is empty, then all of the values we care about are in
+            # the unique key, so there is nothing to UPDATE. We can just do a
+            # SELECT instead to see if it exists.
+            sql = "SELECT 1 FROM %s WHERE %s" % (
+                table,
+                " AND ".join(_getwhere(k) for k in keyvalues)
+            )
+            sqlargs = list(keyvalues.values())
+            txn.execute(sql, sqlargs)
+            if txn.fetchall():
+                # We have an existing record.
+                return False
+        else:
+            # First try to update.
+            sql = "UPDATE %s SET %s WHERE %s" % (
+                table,
+                ", ".join("%s = ?" % (k,) for k in values),
+                " AND ".join(_getwhere(k) for k in keyvalues)
+            )
+            sqlargs = list(values.values()) + list(keyvalues.values())
 
-        txn.execute(sql, sqlargs)
-        if txn.rowcount > 0:
-            # successfully updated at least one row.
-            return False
+            txn.execute(sql, sqlargs)
+            if txn.rowcount > 0:
+                # successfully updated at least one row.
+                return False
 
-        # We didn't update any rows so insert a new one
+        # We didn't find any existing rows, so insert a new one
         allvalues = {}
         allvalues.update(keyvalues)
         allvalues.update(values)
@@ -759,6 +781,106 @@ class SQLBaseStore(object):
         )
         txn.execute(sql, list(allvalues.values()))
 
+    def _simple_upsert_many_txn(
+        self, txn, table, key_names, key_values, value_names, value_values
+    ):
+        """
+        Upsert, many times.
+
+        Args:
+            table (str): The table to upsert into
+            key_names (list[str]): The key column names.
+            key_values (list[list]): A list of each row's key column values.
+            value_names (list[str]): The value column names. If empty, no
+                values will be used, even if value_values is provided.
+            value_values (list[list]): A list of each row's value column values.
+        Returns:
+            None
+        """
+        if (
+            self.database_engine.can_native_upsert
+            and table not in self._unsafe_to_upsert_tables
+        ):
+            return self._simple_upsert_many_txn_native_upsert(
+                txn, table, key_names, key_values, value_names, value_values
+            )
+        else:
+            return self._simple_upsert_many_txn_emulated(
+                txn, table, key_names, key_values, value_names, value_values
+            )
+
+    def _simple_upsert_many_txn_emulated(
+        self, txn, table, key_names, key_values, value_names, value_values
+    ):
+        """
+        Upsert, many times, but without native UPSERT support or batching.
+
+        Args:
+            table (str): The table to upsert into
+            key_names (list[str]): The key column names.
+            key_values (list[list]): A list of each row's key column values.
+            value_names (list[str]): The value column names. If empty, no
+                values will be used, even if value_values is provided.
+            value_values (list[list]): A list of each row's value column values.
+        Returns:
+            None
+        """
+        # No value columns, therefore make a blank list so that the following
+        # zip() works correctly.
+        if not value_names:
+            value_values = [() for x in range(len(key_values))]
+
+        for keyv, valv in zip(key_values, value_values):
+            _keys = {x: y for x, y in zip(key_names, keyv)}
+            _vals = {x: y for x, y in zip(value_names, valv)}
+
+            self._simple_upsert_txn_emulated(txn, table, _keys, _vals)
+
+    def _simple_upsert_many_txn_native_upsert(
+        self, txn, table, key_names, key_values, value_names, value_values
+    ):
+        """
+        Upsert, many times, using batching where possible.
+
+        Args:
+            table (str): The table to upsert into
+            key_names (list[str]): The key column names.
+            key_values (list[list]): A list of each row's key column values.
+            value_names (list[str]): The value column names. If empty, no
+                values will be used, even if value_values is provided.
+            value_values (list[list]): A list of each row's value column values.
+        Returns:
+            None
+        """
+        allnames = []
+        allnames.extend(key_names)
+        allnames.extend(value_names)
+
+        if not value_names:
+            # No value columns, therefore make a blank list so that the
+            # following zip() works correctly.
+            latter = "NOTHING"
+            value_values = [() for x in range(len(key_values))]
+        else:
+            latter = (
+                "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in value_names)
+            )
+
+        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
+            table,
+            ", ".join(k for k in allnames),
+            ", ".join("?" for _ in allnames),
+            ", ".join(key_names),
+            latter,
+        )
+
+        args = []
+
+        for x, y in zip(key_values, value_values):
+            args.append(tuple(x) + tuple(y))
+
+        return txn.execute_batch(sql, args)
+
     def _simple_select_one(self, table, keyvalues, retcols,
                            allow_none=False, desc="_simple_select_one"):
         """Executes a SELECT query on the named table, which is expected to
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 52eb05bfbf..dd49a14524 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -314,3 +315,90 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         self.assertEquals(callcount[0], 2)
         self.assertEquals(callcount2[0], 3)
+
+
+class UpsertManyTests(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.storage = hs.get_datastore()
+
+        self.table_name = "table_" + hs.get_secrets().token_hex(6)
+        self.get_success(
+            self.storage.runInteraction(
+                "create",
+                lambda x, *a: x.execute(*a),
+                "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
+                % (self.table_name,),
+            )
+        )
+        self.get_success(
+            self.storage.runInteraction(
+                "index",
+                lambda x, *a: x.execute(*a),
+                "CREATE UNIQUE INDEX %sindex ON %s(id, username)"
+                % (self.table_name, self.table_name),
+            )
+        )
+
+    def _dump_to_tuple(self, res):
+        for i in res:
+            yield (i["id"], i["username"], i["value"])
+
+    def test_upsert_many(self):
+        """
+        Upsert_many will perform the upsert operation across a batch of data.
+        """
+        # Add some data to an empty table
+        key_names = ["id", "username"]
+        value_names = ["value"]
+        key_values = [[1, "user1"], [2, "user2"]]
+        value_values = [["hello"], ["there"]]
+
+        self.get_success(
+            self.storage.runInteraction(
+                "test",
+                self.storage._simple_upsert_many_txn,
+                self.table_name,
+                key_names,
+                key_values,
+                value_names,
+                value_values,
+            )
+        )
+
+        # Check results are what we expect
+        res = self.get_success(
+            self.storage._simple_select_list(
+                self.table_name, None, ["id, username, value"]
+            )
+        )
+        self.assertEqual(
+            set(self._dump_to_tuple(res)),
+            set([(1, "user1", "hello"), (2, "user2", "there")]),
+        )
+
+        # Update only user2
+        key_values = [[2, "user2"]]
+        value_values = [["bleb"]]
+
+        self.get_success(
+            self.storage.runInteraction(
+                "test",
+                self.storage._simple_upsert_many_txn,
+                self.table_name,
+                key_names,
+                key_values,
+                value_names,
+                value_values,
+            )
+        )
+
+        # Check results are what we expect
+        res = self.get_success(
+            self.storage._simple_select_list(
+                self.table_name, None, ["id, username, value"]
+            )
+        )
+        self.assertEqual(
+            set(self._dump_to_tuple(res)),
+            set([(1, "user1", "hello"), (2, "user2", "bleb")]),
+        )