diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index be9934c66f..c98dd36aed 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -23,7 +23,7 @@ from synapse.util.lrucache import LruCache
from twisted.internet import defer
-import collections
+from collections import namedtuple, OrderedDict
import simplejson as json
import sys
import time
@@ -35,6 +35,52 @@ sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
+# TODO(paul):
+# * more generic key management
+# * export monitoring stats
+# * consider other eviction strategies - LRU?
+def cached(max_entries=1000):
+ """ A method decorator that applies a memoizing cache around the function.
+
+ The function is presumed to take one additional argument, which is used as
+ the key for the cache. Cache hits are served directly from the cache;
+ misses use the function body to generate the value.
+
+ The wrapped function has an additional member, a callable called
+ "invalidate". This can be used to remove individual entries from the cache.
+
+ The wrapped function has another additional callable, called "prefill",
+ which can be used to insert values into the cache specifically, without
+ calling the calculation function.
+ """
+ def wrap(orig):
+ cache = OrderedDict()
+
+ def prefill(key, value):
+ while len(cache) > max_entries:
+ cache.popitem(last=False)
+
+ cache[key] = value
+
+ @defer.inlineCallbacks
+ def wrapped(self, key):
+ if key in cache:
+ defer.returnValue(cache[key])
+
+ ret = yield orig(self, key)
+ prefill(key, ret)
+ defer.returnValue(ret)
+
+ def invalidate(key):
+ cache.pop(key, None)
+
+ wrapped.invalidate = invalidate
+ wrapped.prefill = prefill
+ return wrapped
+
+ return wrap
+
+
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging to the .execute() method."""
@@ -586,8 +632,9 @@ class SQLBaseStore(object):
start_time = time.time() * 1000
update_counter = self._get_event_counters.update
+ cache = self._get_event_cache.setdefault(event_id, {})
+
try:
- cache = self._get_event_cache.setdefault(event_id, {})
# Separate cache entries for each way to invoke _get_event_txn
return cache[(check_redacted, get_prev_content, allow_rejected)]
except KeyError:
@@ -786,7 +833,7 @@ class JoinHelper(object):
for table in self.tables:
res += [f for f in table.fields if f not in res]
- self.EntryType = collections.namedtuple("JoinHelperEntry", res)
+ self.EntryType = namedtuple("JoinHelperEntry", res)
def get_fields(self, **prefixes):
"""Get a string representing a list of fields for use in SELECT
|