summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py75
1 files changed, 65 insertions, 10 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index be9934c66f..3725c9795d 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -23,7 +23,7 @@ from synapse.util.lrucache import LruCache
 
 from twisted.internet import defer
 
-import collections
+from collections import namedtuple, OrderedDict
 import simplejson as json
 import sys
 import time
@@ -35,6 +35,52 @@ sql_logger = logging.getLogger("synapse.storage.SQL")
 transaction_logger = logging.getLogger("synapse.storage.txn")
 
 
+# TODO(paul):
+#  * more generic key management
+#  * export monitoring stats
+#  * consider other eviction strategies - LRU?
+def cached(max_entries=1000):
+    """ A method decorator that applies a memoizing cache around the function.
+
+    The function is presumed to take one additional argument, which is used as
+    the key for the cache. Cache hits are served directly from the cache;
+    misses use the function body to generate the value.
+
+    The wrapped function has an additional member, a callable called
+    "invalidate". This can be used to remove individual entries from the cache.
+
+    The wrapped function has another additional callable, called "prefill",
+    which can be used to insert values into the cache specifically, without
+    calling the calculation function.
+    """
+    def wrap(orig):
+        cache = OrderedDict()
+
+        def prefill(key, value):
+            while len(cache) > max_entries:
+                cache.popitem(last=False)
+
+            cache[key] = value
+
+        @defer.inlineCallbacks
+        def wrapped(self, key):
+            if key in cache:
+                defer.returnValue(cache[key])
+
+            ret = yield orig(self, key)
+            prefill(key, ret)
+            defer.returnValue(ret)
+
+        def invalidate(key):
+            cache.pop(key, None)
+
+        wrapped.invalidate = invalidate
+        wrapped.prefill = prefill
+        return wrapped
+
+    return wrap
+
+
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging to the .execute() method."""
@@ -404,7 +450,8 @@ class SQLBaseStore(object):
 
         Args:
             table : string giving the table name
-            keyvalues : dict of column names and values to select the rows with
+            keyvalues : dict of column names and values to select the rows with,
+            or None to not apply a WHERE clause.
             retcols : list of strings giving the names of the columns to return
         """
         return self.runInteraction(
@@ -423,13 +470,20 @@ class SQLBaseStore(object):
             keyvalues : dict of column names and values to select the rows with
             retcols : list of strings giving the names of the columns to return
         """
-        sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
-            ", ".join(retcols),
-            table,
-            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
-        )
+        if keyvalues:
+            sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+                ", ".join(retcols),
+                table,
+                " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+            )
+            txn.execute(sql, keyvalues.values())
+        else:
+            sql = "SELECT %s FROM %s ORDER BY rowid asc" % (
+                ", ".join(retcols),
+                table
+            )
+            txn.execute(sql)
 
-        txn.execute(sql, keyvalues.values())
         return self.cursor_to_dict(txn)
 
     def _simple_update_one(self, table, keyvalues, updatevalues,
@@ -586,8 +640,9 @@ class SQLBaseStore(object):
         start_time = time.time() * 1000
         update_counter = self._get_event_counters.update
 
+        cache = self._get_event_cache.setdefault(event_id, {})
+
         try:
-            cache = self._get_event_cache.setdefault(event_id, {})
             # Separate cache entries for each way to invoke _get_event_txn
             return cache[(check_redacted, get_prev_content, allow_rejected)]
         except KeyError:
@@ -786,7 +841,7 @@ class JoinHelper(object):
         for table in self.tables:
             res += [f for f in table.fields if f not in res]
 
-        self.EntryType = collections.namedtuple("JoinHelperEntry", res)
+        self.EntryType = namedtuple("JoinHelperEntry", res)
 
     def get_fields(self, **prefixes):
         """Get a string representing a list of fields for use in SELECT