summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py62
1 files changed, 49 insertions, 13 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index a2da3dd1b1..ee5587c721 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -206,18 +206,23 @@ class LoggingTransaction(object):
     def __setattr__(self, name, value):
         setattr(self.txn, name, value)
 
-    def execute(self, sql, *args, **kwargs):
+    def execute(self, sql, *args):
+        self._do_execute(self.txn.execute, sql, *args)
+
+    def executemany(self, sql, *args):
+        self._do_execute(self.txn.executemany, sql, *args)
+
+    def _do_execute(self, func, sql, *args):
         # TODO(paul): Maybe use 'info' and 'debug' for values?
         sql_logger.debug("[SQL] {%s} %s", self.name, sql)
 
         sql = self.database_engine.convert_param_style(sql)
 
-        if args and args[0]:
+        if args:
             try:
                 sql_logger.debug(
-                    "[SQL values] {%s} " + ", ".join(("<%r>",) * len(args[0])),
-                    self.name,
-                    *args[0]
+                    "[SQL values] {%s} %r",
+                    self.name, args[0]
                 )
             except:
                 # Don't let logging failures stop SQL from working
@@ -226,8 +231,8 @@ class LoggingTransaction(object):
         start = time.time() * 1000
 
         try:
-            return self.txn.execute(
-                sql, *args, **kwargs
+            return func(
+                sql, *args
             )
         except Exception as e:
             logger.debug("[SQL FAIL] {%s} %s", self.name, e)
@@ -484,18 +489,49 @@ class SQLBaseStore(object):
 
     @log_function
     def _simple_insert_txn(self, txn, table, values):
+        keys, vals = zip(*values.items())
+
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
             table,
-            ", ".join(k for k in values),
-            ", ".join("?" for k in values)
+            ", ".join(k for k in keys),
+            ", ".join("?" for _ in keys)
         )
 
-        logger.debug(
-            "[SQL] %s Args=%s",
-            sql, values.values(),
+        txn.execute(sql, vals)
+
+    def _simple_insert_many_txn(self, txn, table, values):
+        if not values:
+            return
+
+        # This is a *slight* abomination to get a list of tuples of key names
+        # and a list of tuples of value names.
+        #
+        # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
+        #         => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
+        #
+        # The sort is to ensure that we don't rely on dictionary iteration
+        # order.
+        keys, vals = zip(*[
+            zip(
+                *(sorted(i.items(), key=lambda kv: kv[0]))
+            )
+            for i in values
+            if i
+        ])
+
+        for k in keys:
+            if k != keys[0]:
+                raise RuntimeError(
+                    "All items must have the same keys"
+                )
+
+        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+            table,
+            ", ".join(k for k in keys[0]),
+            ", ".join("?" for _ in keys[0])
         )
 
-        txn.execute(sql, values.values())
+        txn.executemany(sql, vals)
 
     def _simple_upsert(self, table, keyvalues, values,
                        insertion_values={}, desc="_simple_upsert", lock=True):