summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2015-01-28 17:29:30 +0000
committerMark Haines <mark.haines@matrix.org>2015-01-28 17:29:30 +0000
commit9c61556504924a063e7739ba80122c8753aed50c (patch)
treea483e292dfe602f30be1f9f17266a1bd0f88552d /synapse/storage/_base.py
parentFix Formatting (diff)
parentMerge pull request #36 from matrix-org/device_id_from_access_token (diff)
downloadsynapse-9c61556504924a063e7739ba80122c8753aed50c.tar.xz
Merge branch 'develop' into client_v2_sync
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py48
1 files changed, 46 insertions, 2 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index f660fc6eaf..4e8bd3faa9 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -193,6 +193,50 @@ class SQLBaseStore(object):
         txn.execute(sql, values.values())
         return txn.lastrowid
 
+    def _simple_upsert(self, table, keyvalues, values):
+        """
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key tables and their new values
+            values (dict): The nonunique columns and their new values
+        Returns: A deferred
+        """
+        return self.runInteraction(
+            "_simple_upsert",
+            self._simple_upsert_txn, table, keyvalues, values
+        )
+
+    def _simple_upsert_txn(self, txn, table, keyvalues, values):
+        # Try to update
+        sql = "UPDATE %s SET %s WHERE %s" % (
+            table,
+            ", ".join("%s = ?" % (k,) for k in values),
+            " AND ".join("%s = ?" % (k,) for k in keyvalues)
+        )
+        sqlargs = values.values() + keyvalues.values()
+        logger.debug(
+            "[SQL] %s Args=%s",
+            sql, sqlargs,
+        )
+
+        txn.execute(sql, sqlargs)
+        if txn.rowcount == 0:
+            # We didn't update and rows so insert a new one
+            allvalues = {}
+            allvalues.update(keyvalues)
+            allvalues.update(values)
+
+            sql = "INSERT INTO %s (%s) VALUES (%s)" % (
+                table,
+                ", ".join(k for k in allvalues),
+                ", ".join("?" for _ in allvalues)
+            )
+            logger.debug(
+                "[SQL] %s Args=%s",
+                sql, keyvalues.values(),
+            )
+            txn.execute(sql, allvalues.values())
+
     def _simple_select_one(self, table, keyvalues, retcols,
                            allow_none=False):
         """Executes a SELECT query on the named table, which is expected to
@@ -344,8 +388,8 @@ class SQLBaseStore(object):
         if updatevalues:
             update_sql = "UPDATE %s SET %s WHERE %s" % (
                 table,
-                ", ".join("%s = ?" % (k) for k in updatevalues),
-                " AND ".join("%s = ?" % (k) for k in keyvalues)
+                ", ".join("%s = ?" % (k,) for k in updatevalues),
+                " AND ".join("%s = ?" % (k,) for k in keyvalues)
             )
 
         def func(txn):