summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py173
1 files changed, 110 insertions, 63 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 37bb28e6cf..d038c55092 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -25,6 +25,7 @@ import synapse.metrics
 from twisted.internet import defer
 
 from collections import namedtuple, OrderedDict
+import functools
 import simplejson as json
 import sys
 import time
@@ -53,13 +54,12 @@ cache_counter = metrics.register_cache(
 
 
 # TODO(paul):
-#  * more generic key management
 #  * consider other eviction strategies - LRU?
-def cached(max_entries=1000):
+def cached(max_entries=1000, num_args=1):
     """ A method decorator that applies a memoizing cache around the function.
 
-    The function is presumed to take one additional argument, which is used as
-    the key for the cache. Cache hits are served directly from the cache;
+    The function is presumed to take zero or more arguments, which are used in
+    a tuple as the key for the cache. Hits are served directly from the cache;
     misses use the function body to generate the value.
 
     The wrapped function has an additional member, a callable called
@@ -75,25 +75,41 @@ def cached(max_entries=1000):
 
         caches_by_name[name] = cache
 
-        def prefill(key, value):
+        def prefill(*args):  # because I can't  *keyargs, value
+            keyargs = args[:-1]
+            value = args[-1]
+
+            if len(keyargs) != num_args:
+                raise ValueError("Expected a call to have %d arguments", num_args)
+
             while len(cache) > max_entries:
                 cache.popitem(last=False)
 
-            cache[key] = value
+            cache[keyargs] = value
 
+        @functools.wraps(orig)
         @defer.inlineCallbacks
-        def wrapped(self, key):
-            if key in cache:
+        def wrapped(self, *keyargs):
+            if len(keyargs) != num_args:
+                raise ValueError("Expected a call to have %d arguments", num_args)
+
+            if keyargs in cache:
                 cache_counter.inc_hits(name)
-                defer.returnValue(cache[key])
+                defer.returnValue(cache[keyargs])
 
             cache_counter.inc_misses(name)
-            ret = yield orig(self, key)
-            prefill(key, ret)
+            ret = yield orig(self, *keyargs)
+
+            prefill_args = keyargs + (ret,)
+            prefill(*prefill_args)
+
             defer.returnValue(ret)
 
-        def invalidate(key):
-            cache.pop(key, None)
+        def invalidate(*keyargs):
+            if len(keyargs) != num_args:
+                raise ValueError("Expected a call to have %d arguments", num_args)
+
+            cache.pop(keyargs, None)
 
         wrapped.invalidate = invalidate
         wrapped.prefill = prefill
@@ -325,7 +341,8 @@ class SQLBaseStore(object):
     # "Simple" SQL API methods that operate on a single table with no JOINs,
     # no complex WHERE clauses, just a dict of values for columns.
 
-    def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
+    def _simple_insert(self, table, values, or_replace=False, or_ignore=False,
+                       desc="_simple_insert"):
         """Executes an INSERT query on the named table.
 
         Args:
@@ -334,7 +351,7 @@ class SQLBaseStore(object):
             or_replace : bool; if True performs an INSERT OR REPLACE
         """
         return self.runInteraction(
-            "_simple_insert",
+            desc,
             self._simple_insert_txn, table, values, or_replace=or_replace,
             or_ignore=or_ignore,
         )
@@ -357,7 +374,7 @@ class SQLBaseStore(object):
         txn.execute(sql, values.values())
         return txn.lastrowid
 
-    def _simple_upsert(self, table, keyvalues, values):
+    def _simple_upsert(self, table, keyvalues, values, desc="_simple_upsert"):
         """
         Args:
             table (str): The table to upsert into
@@ -366,7 +383,7 @@ class SQLBaseStore(object):
         Returns: A deferred
         """
         return self.runInteraction(
-            "_simple_upsert",
+            desc,
             self._simple_upsert_txn, table, keyvalues, values
         )
 
@@ -402,7 +419,7 @@ class SQLBaseStore(object):
             txn.execute(sql, allvalues.values())
 
     def _simple_select_one(self, table, keyvalues, retcols,
-                           allow_none=False):
+                           allow_none=False, desc="_simple_select_one"):
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning a single column from it.
 
@@ -414,12 +431,15 @@ class SQLBaseStore(object):
             allow_none : If true, return None instead of failing if the SELECT
               statement returns no rows
         """
-        return self._simple_selectupdate_one(
-            table, keyvalues, retcols=retcols, allow_none=allow_none
+        return self.runInteraction(
+            desc,
+            self._simple_select_one_txn,
+            table, keyvalues, retcols, allow_none,
         )
 
     def _simple_select_one_onecol(self, table, keyvalues, retcol,
-                                  allow_none=False):
+                                  allow_none=False,
+                                  desc="_simple_select_one_onecol"):
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning a single column from it."
 
@@ -429,7 +449,7 @@ class SQLBaseStore(object):
             retcol : string giving the name of the column to return
         """
         return self.runInteraction(
-            "_simple_select_one_onecol",
+            desc,
             self._simple_select_one_onecol_txn,
             table, keyvalues, retcol, allow_none=allow_none,
         )
@@ -464,7 +484,8 @@ class SQLBaseStore(object):
 
         return [r[0] for r in txn.fetchall()]
 
-    def _simple_select_onecol(self, table, keyvalues, retcol):
+    def _simple_select_onecol(self, table, keyvalues, retcol,
+                              desc="_simple_select_onecol"):
         """Executes a SELECT query on the named table, which returns a list
         comprising of the values of the named column from the selected rows.
 
@@ -477,12 +498,13 @@ class SQLBaseStore(object):
             Deferred: Results in a list
         """
         return self.runInteraction(
-            "_simple_select_onecol",
+            desc,
             self._simple_select_onecol_txn,
             table, keyvalues, retcol
         )
 
-    def _simple_select_list(self, table, keyvalues, retcols):
+    def _simple_select_list(self, table, keyvalues, retcols,
+                            desc="_simple_select_list"):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
@@ -493,7 +515,7 @@ class SQLBaseStore(object):
             retcols : list of strings giving the names of the columns to return
         """
         return self.runInteraction(
-            "_simple_select_list",
+            desc,
             self._simple_select_list_txn,
             table, keyvalues, retcols
         )
@@ -525,7 +547,7 @@ class SQLBaseStore(object):
         return self.cursor_to_dict(txn)
 
     def _simple_update_one(self, table, keyvalues, updatevalues,
-                           retcols=None):
+                           desc="_simple_update_one"):
         """Executes an UPDATE query on the named table, setting new values for
         columns in a row matching the key values.
 
@@ -543,45 +565,70 @@ class SQLBaseStore(object):
         get-and-set.  This can be used to implement compare-and-set by putting
         the update column in the 'keyvalues' dict as well.
         """
-        return self._simple_selectupdate_one(table, keyvalues, updatevalues,
-                                             retcols=retcols)
+        return self.runInteraction(
+            desc,
+            self._simple_update_one_txn,
+            table, keyvalues, updatevalues,
+        )
 
-    def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
-                                 retcols=None, allow_none=False):
-        """ Combined SELECT then UPDATE."""
-        if retcols:
-            select_sql = "SELECT %s FROM %s WHERE %s" % (
-                ", ".join(retcols),
-                table,
-                " AND ".join("%s = ?" % (k) for k in keyvalues)
-            )
+    def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
+        update_sql = "UPDATE %s SET %s WHERE %s" % (
+            table,
+            ", ".join("%s = ?" % (k,) for k in updatevalues),
+            " AND ".join("%s = ?" % (k,) for k in keyvalues)
+        )
 
-        if updatevalues:
-            update_sql = "UPDATE %s SET %s WHERE %s" % (
-                table,
-                ", ".join("%s = ?" % (k,) for k in updatevalues),
-                " AND ".join("%s = ?" % (k,) for k in keyvalues)
-            )
+        txn.execute(
+            update_sql,
+            updatevalues.values() + keyvalues.values()
+        )
+
+        if txn.rowcount == 0:
+            raise StoreError(404, "No row found")
+        if txn.rowcount > 1:
+            raise StoreError(500, "More than one row matched")
+
+    def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
+                               allow_none=False):
+        select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+            ", ".join(retcols),
+            table,
+            " AND ".join("%s = ?" % (k) for k in keyvalues)
+        )
+
+        txn.execute(select_sql, keyvalues.values())
+
+        row = txn.fetchone()
+        if not row:
+            if allow_none:
+                return None
+            raise StoreError(404, "No row found")
+        if txn.rowcount > 1:
+            raise StoreError(500, "More than one row matched")
+
+        return dict(zip(retcols, row))
 
+    def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
+                                 retcols=None, allow_none=False,
+                                 desc="_simple_selectupdate_one"):
+        """ Combined SELECT then UPDATE."""
         def func(txn):
             ret = None
             if retcols:
-                txn.execute(select_sql, keyvalues.values())
-
-                row = txn.fetchone()
-                if not row:
-                    if allow_none:
-                        return None
-                    raise StoreError(404, "No row found")
-                if txn.rowcount > 1:
-                    raise StoreError(500, "More than one row matched")
-
-                ret = dict(zip(retcols, row))
+                ret = self._simple_select_one_txn(
+                    txn,
+                    table=table,
+                    keyvalues=keyvalues,
+                    retcols=retcols,
+                    allow_none=allow_none,
+                )
 
             if updatevalues:
-                txn.execute(
-                    update_sql,
-                    updatevalues.values() + keyvalues.values()
+                self._simple_update_one_txn(
+                    txn,
+                    table=table,
+                    keyvalues=keyvalues,
+                    updatevalues=updatevalues,
                 )
 
                 # if txn.rowcount == 0:
@@ -590,9 +637,9 @@ class SQLBaseStore(object):
                     raise StoreError(500, "More than one row matched")
 
             return ret
-        return self.runInteraction("_simple_selectupdate_one", func)
+        return self.runInteraction(desc, func)
 
-    def _simple_delete_one(self, table, keyvalues):
+    def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
         """Executes a DELETE query on the named table, expecting to delete a
         single row.
 
@@ -611,9 +658,9 @@ class SQLBaseStore(object):
                 raise StoreError(404, "No row found")
             if txn.rowcount > 1:
                 raise StoreError(500, "more than one row matched")
-        return self.runInteraction("_simple_delete_one", func)
+        return self.runInteraction(desc, func)
 
-    def _simple_delete(self, table, keyvalues):
+    def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
         """Executes a DELETE query on the named table.
 
         Args:
@@ -621,7 +668,7 @@ class SQLBaseStore(object):
             keyvalues : dict of column names and values to select the row with
         """
 
-        return self.runInteraction("_simple_delete", self._simple_delete_txn)
+        return self.runInteraction(desc, self._simple_delete_txn)
 
     def _simple_delete_txn(self, txn, table, keyvalues):
         sql = "DELETE FROM %s WHERE %s" % (