summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py89
1 files changed, 54 insertions, 35 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2552a74f85..27ea65a0f6 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -53,6 +53,47 @@ cache_counter = metrics.register_cache(
 )
 
 
+class Cache(object):
+
+    def __init__(self, name, max_entries=1000, keylen=1):
+        self.cache = OrderedDict()
+
+        self.max_entries = max_entries
+        self.name = name
+        self.keylen = keylen
+
+        caches_by_name[name] = self.cache
+
+    def get(self, *keyargs):
+        if len(keyargs) != self.keylen:
+            raise ValueError("Expected a key to have %d items", self.keylen)
+
+        if keyargs in self.cache:
+            cache_counter.inc_hits(self.name)
+            return self.cache[keyargs]
+
+        cache_counter.inc_misses(self.name)
+        raise KeyError()
+
+    def prefill(self, *args):  # because I can't  *keyargs, value
+        keyargs = args[:-1]
+        value = args[-1]
+
+        if len(keyargs) != self.keylen:
+            raise ValueError("Expected a key to have %d items", self.keylen)
+
+        while len(self.cache) > self.max_entries:
+            self.cache.popitem(last=False)
+
+        self.cache[keyargs] = value
+
+    def invalidate(self, *keyargs):
+        if len(keyargs) != self.keylen:
+            raise ValueError("Expected a key to have %d items", self.keylen)
+
+        self.cache.pop(keyargs, None)
+
+
 # TODO(paul):
 #  * consider other eviction strategies - LRU?
 def cached(max_entries=1000, num_args=1):
@@ -70,48 +111,26 @@ def cached(max_entries=1000, num_args=1):
     calling the calculation function.
     """
     def wrap(orig):
-        cache = OrderedDict()
-        name = orig.__name__
-
-        caches_by_name[name] = cache
-
-        def prefill(*args):  # because I can't  *keyargs, value
-            keyargs = args[:-1]
-            value = args[-1]
-
-            if len(keyargs) != num_args:
-                raise ValueError("Expected a call to have %d arguments", num_args)
-
-            while len(cache) > max_entries:
-                cache.popitem(last=False)
-
-            cache[keyargs] = value
+        cache = Cache(
+            name=orig.__name__,
+            max_entries=max_entries,
+            keylen=num_args,
+        )
 
         @functools.wraps(orig)
         @defer.inlineCallbacks
         def wrapped(self, *keyargs):
-            if len(keyargs) != num_args:
-                raise ValueError("Expected a call to have %d arguments", num_args)
-
-            if keyargs in cache:
-                cache_counter.inc_hits(name)
-                defer.returnValue(cache[keyargs])
-
-            cache_counter.inc_misses(name)
-            ret = yield orig(self, *keyargs)
-
-            prefill(*keyargs + (ret,))
-
-            defer.returnValue(ret)
+            try:
+                defer.returnValue(cache.get(*keyargs))
+            except KeyError:
+                ret = yield orig(self, *keyargs)
 
-        def invalidate(*keyargs):
-            if len(keyargs) != num_args:
-                raise ValueError("Expected a call to have %d arguments", num_args)
+                cache.prefill(*keyargs + (ret,))
 
-            cache.pop(keyargs, None)
+                defer.returnValue(ret)
 
-        wrapped.invalidate = invalidate
-        wrapped.prefill = prefill
+        wrapped.invalidate = cache.invalidate
+        wrapped.prefill = cache.prefill
         return wrapped
 
     return wrap