diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 39884c2afe..8d33def6c6 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -127,7 +127,7 @@ class Cache(object):
self.cache.clear()
-def cached(max_entries=1000, num_args=1, lru=False):
+class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
The function is presumed to take zero or more arguments, which are used in
@@ -141,25 +141,32 @@ def cached(max_entries=1000, num_args=1, lru=False):
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
- def wrap(orig):
+ def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
+ self.orig = orig
+
+ self.max_entries = max_entries
+ self.num_args = num_args
+ self.lru = lru
+
+ def __get__(self, obj, objtype=None):
cache = Cache(
- name=orig.__name__,
- max_entries=max_entries,
- keylen=num_args,
- lru=lru,
+ name=self.orig.__name__,
+ max_entries=self.max_entries,
+ keylen=self.num_args,
+ lru=self.lru,
)
- @functools.wraps(orig)
+ @functools.wraps(self.orig)
@defer.inlineCallbacks
- def wrapped(self, *keyargs):
+ def wrapped(*keyargs):
try:
- cached_result = cache.get(*keyargs)
+ cached_result = cache.get(*keyargs[:self.num_args])
if DEBUG_CACHES:
- actual_result = yield orig(self, *keyargs)
+ actual_result = yield self.orig(obj, *keyargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
- orig.__name__, keyargs,
+ self.orig.__name__, keyargs,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
@@ -170,18 +177,28 @@ def cached(max_entries=1000, num_args=1, lru=False):
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
- ret = yield orig(self, *keyargs)
+ ret = yield self.orig(obj, *keyargs)
- cache.update(sequence, *keyargs + (ret,))
+ cache.update(sequence, *keyargs[:self.num_args] + (ret,))
defer.returnValue(ret)
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.prefill = cache.prefill
+
+ obj.__dict__[self.orig.__name__] = wrapped
+
return wrapped
- return wrap
+
+def cached(max_entries=1000, num_args=1, lru=False):
+ return lambda orig: CacheDescriptor(
+ orig,
+ max_entries=max_entries,
+ num_args=num_args,
+ lru=lru
+ )
class LoggingTransaction(object):
|