summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-04-07 18:05:39 +0100
committerErik Johnston <erik@matrix.org>2015-04-07 18:05:39 +0100
commit4fe95094d1aa9a8a36a32c56d5665ddba825e029 (patch)
treed7e88a7b2ce0d41403c7a7afaff3b44088e60324 /synapse/storage/_base.py
parentRetry on deadlock (diff)
parentupdate leo's contribs a bit (diff)
downloadsynapse-4fe95094d1aa9a8a36a32c56d5665ddba825e029.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into mysql
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py100
1 files changed, 61 insertions, 39 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index c15cec0c78..20fc1d0bb9 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -54,9 +54,53 @@ cache_counter = metrics.register_cache(
 )
 
 
-# TODO(paul):
-#  * consider other eviction strategies - LRU?
-def cached(max_entries=1000, num_args=1):
+class Cache(object):
+
+    def __init__(self, name, max_entries=1000, keylen=1, lru=False):
+        if lru:
+            self.cache = LruCache(max_size=max_entries)
+            self.max_entries = None
+        else:
+            self.cache = OrderedDict()
+            self.max_entries = max_entries
+
+        self.name = name
+        self.keylen = keylen
+
+        caches_by_name[name] = self.cache
+
+    def get(self, *keyargs):
+        if len(keyargs) != self.keylen:
+            raise ValueError("Expected a key to have %d items", self.keylen)
+
+        if keyargs in self.cache:
+            cache_counter.inc_hits(self.name)
+            return self.cache[keyargs]
+
+        cache_counter.inc_misses(self.name)
+        raise KeyError()
+
+    def prefill(self, *args):  # because I can't  *keyargs, value
+        keyargs = args[:-1]
+        value = args[-1]
+
+        if len(keyargs) != self.keylen:
+            raise ValueError("Expected a key to have %d items", self.keylen)
+
+        if self.max_entries is not None:
+            while len(self.cache) >= self.max_entries:
+                self.cache.popitem(last=False)
+
+        self.cache[keyargs] = value
+
+    def invalidate(self, *keyargs):
+        if len(keyargs) != self.keylen:
+            raise ValueError("Expected a key to have %d items", self.keylen)
+
+        self.cache.pop(keyargs, None)
+
+
+def cached(max_entries=1000, num_args=1, lru=False):
     """ A method decorator that applies a memoizing cache around the function.
 
     The function is presumed to take zero or more arguments, which are used in
@@ -71,49 +115,27 @@ def cached(max_entries=1000, num_args=1):
     calling the calculation function.
     """
     def wrap(orig):
-        cache = OrderedDict()
-        name = orig.__name__
-
-        caches_by_name[name] = cache
-
-        def prefill(*args):  # because I can't  *keyargs, value
-            keyargs = args[:-1]
-            value = args[-1]
-
-            if len(keyargs) != num_args:
-                raise ValueError("Expected a call to have %d arguments", num_args)
-
-            while len(cache) > max_entries:
-                cache.popitem(last=False)
-
-            cache[keyargs] = value
+        cache = Cache(
+            name=orig.__name__,
+            max_entries=max_entries,
+            keylen=num_args,
+            lru=lru,
+        )
 
         @functools.wraps(orig)
         @defer.inlineCallbacks
         def wrapped(self, *keyargs):
-            if len(keyargs) != num_args:
-                raise ValueError("Expected a call to have %d arguments", num_args)
-
-            if keyargs in cache:
-                cache_counter.inc_hits(name)
-                defer.returnValue(cache[keyargs])
-
-            cache_counter.inc_misses(name)
-            ret = yield orig(self, *keyargs)
-
-            prefill_args = keyargs + (ret,)
-            prefill(*prefill_args)
-
-            defer.returnValue(ret)
+            try:
+                defer.returnValue(cache.get(*keyargs))
+            except KeyError:
+                ret = yield orig(self, *keyargs)
 
-        def invalidate(*keyargs):
-            if len(keyargs) != num_args:
-                raise ValueError("Expected a call to have %d arguments", num_args)
+                cache.prefill(*keyargs + (ret,))
 
-            cache.pop(keyargs, None)
+                defer.returnValue(ret)
 
-        wrapped.invalidate = invalidate
-        wrapped.prefill = prefill
+        wrapped.invalidate = cache.invalidate
+        wrapped.prefill = cache.prefill
         return wrapped
 
     return wrap