diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index c15cec0c78..20fc1d0bb9 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -54,9 +54,53 @@ cache_counter = metrics.register_cache(
)
-# TODO(paul):
-# * consider other eviction strategies - LRU?
-def cached(max_entries=1000, num_args=1):
+class Cache(object):
+
+ def __init__(self, name, max_entries=1000, keylen=1, lru=False):
+ if lru:
+ self.cache = LruCache(max_size=max_entries)
+ self.max_entries = None
+ else:
+ self.cache = OrderedDict()
+ self.max_entries = max_entries
+
+ self.name = name
+ self.keylen = keylen
+
+ caches_by_name[name] = self.cache
+
+ def get(self, *keyargs):
+ if len(keyargs) != self.keylen:
+ raise ValueError("Expected a key to have %d items", self.keylen)
+
+ if keyargs in self.cache:
+ cache_counter.inc_hits(self.name)
+ return self.cache[keyargs]
+
+ cache_counter.inc_misses(self.name)
+ raise KeyError()
+
+ def prefill(self, *args): # because I can't *keyargs, value
+ keyargs = args[:-1]
+ value = args[-1]
+
+ if len(keyargs) != self.keylen:
+ raise ValueError("Expected a key to have %d items", self.keylen)
+
+ if self.max_entries is not None:
+ while len(self.cache) >= self.max_entries:
+ self.cache.popitem(last=False)
+
+ self.cache[keyargs] = value
+
+ def invalidate(self, *keyargs):
+ if len(keyargs) != self.keylen:
+ raise ValueError("Expected a key to have %d items", self.keylen)
+
+ self.cache.pop(keyargs, None)
+
+
+def cached(max_entries=1000, num_args=1, lru=False):
""" A method decorator that applies a memoizing cache around the function.
The function is presumed to take zero or more arguments, which are used in
@@ -71,49 +115,27 @@ def cached(max_entries=1000, num_args=1):
calling the calculation function.
"""
def wrap(orig):
- cache = OrderedDict()
- name = orig.__name__
-
- caches_by_name[name] = cache
-
- def prefill(*args): # because I can't *keyargs, value
- keyargs = args[:-1]
- value = args[-1]
-
- if len(keyargs) != num_args:
- raise ValueError("Expected a call to have %d arguments", num_args)
-
- while len(cache) > max_entries:
- cache.popitem(last=False)
-
- cache[keyargs] = value
+ cache = Cache(
+ name=orig.__name__,
+ max_entries=max_entries,
+ keylen=num_args,
+ lru=lru,
+ )
@functools.wraps(orig)
@defer.inlineCallbacks
def wrapped(self, *keyargs):
- if len(keyargs) != num_args:
- raise ValueError("Expected a call to have %d arguments", num_args)
-
- if keyargs in cache:
- cache_counter.inc_hits(name)
- defer.returnValue(cache[keyargs])
-
- cache_counter.inc_misses(name)
- ret = yield orig(self, *keyargs)
-
- prefill_args = keyargs + (ret,)
- prefill(*prefill_args)
-
- defer.returnValue(ret)
+ try:
+ defer.returnValue(cache.get(*keyargs))
+ except KeyError:
+ ret = yield orig(self, *keyargs)
- def invalidate(*keyargs):
- if len(keyargs) != num_args:
- raise ValueError("Expected a call to have %d arguments", num_args)
+ cache.prefill(*keyargs + (ret,))
- cache.pop(keyargs, None)
+ defer.returnValue(ret)
- wrapped.invalidate = invalidate
- wrapped.prefill = prefill
+ wrapped.invalidate = cache.invalidate
+ wrapped.prefill = cache.prefill
return wrapped
return wrap
|