summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/_base.py89
-rw-r--r--tests/storage/test__base.py34
2 files changed, 87 insertions, 36 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2552a74f85..27ea65a0f6 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -53,6 +53,47 @@ cache_counter = metrics.register_cache(
 )
 
 
+class Cache(object):
+
+    def __init__(self, name, max_entries=1000, keylen=1):
+        self.cache = OrderedDict()
+
+        self.max_entries = max_entries
+        self.name = name
+        self.keylen = keylen
+
+        caches_by_name[name] = self.cache
+
+    def get(self, *keyargs):
+        if len(keyargs) != self.keylen:
+            raise ValueError("Expected a key to have %d items", self.keylen)
+
+        if keyargs in self.cache:
+            cache_counter.inc_hits(self.name)
+            return self.cache[keyargs]
+
+        cache_counter.inc_misses(self.name)
+        raise KeyError()
+
+    def prefill(self, *args):  # because I can't  *keyargs, value
+        keyargs = args[:-1]
+        value = args[-1]
+
+        if len(keyargs) != self.keylen:
+            raise ValueError("Expected a key to have %d items", self.keylen)
+
+        while len(self.cache) > self.max_entries:
+            self.cache.popitem(last=False)
+
+        self.cache[keyargs] = value
+
+    def invalidate(self, *keyargs):
+        if len(keyargs) != self.keylen:
+            raise ValueError("Expected a key to have %d items", self.keylen)
+
+        self.cache.pop(keyargs, None)
+
+
 # TODO(paul):
 #  * consider other eviction strategies - LRU?
 def cached(max_entries=1000, num_args=1):
@@ -70,48 +111,26 @@ def cached(max_entries=1000, num_args=1):
     calling the calculation function.
     """
     def wrap(orig):
-        cache = OrderedDict()
-        name = orig.__name__
-
-        caches_by_name[name] = cache
-
-        def prefill(*args):  # because I can't  *keyargs, value
-            keyargs = args[:-1]
-            value = args[-1]
-
-            if len(keyargs) != num_args:
-                raise ValueError("Expected a call to have %d arguments", num_args)
-
-            while len(cache) > max_entries:
-                cache.popitem(last=False)
-
-            cache[keyargs] = value
+        cache = Cache(
+            name=orig.__name__,
+            max_entries=max_entries,
+            keylen=num_args,
+        )
 
         @functools.wraps(orig)
         @defer.inlineCallbacks
         def wrapped(self, *keyargs):
-            if len(keyargs) != num_args:
-                raise ValueError("Expected a call to have %d arguments", num_args)
-
-            if keyargs in cache:
-                cache_counter.inc_hits(name)
-                defer.returnValue(cache[keyargs])
-
-            cache_counter.inc_misses(name)
-            ret = yield orig(self, *keyargs)
-
-            prefill(*keyargs + (ret,))
-
-            defer.returnValue(ret)
+            try:
+                defer.returnValue(cache.get(*keyargs))
+            except KeyError:
+                ret = yield orig(self, *keyargs)
 
-        def invalidate(*keyargs):
-            if len(keyargs) != num_args:
-                raise ValueError("Expected a call to have %d arguments", num_args)
+                cache.prefill(*keyargs + (ret,))
 
-            cache.pop(keyargs, None)
+                defer.returnValue(ret)
 
-        wrapped.invalidate = invalidate
-        wrapped.prefill = prefill
+        wrapped.invalidate = cache.invalidate
+        wrapped.prefill = cache.prefill
         return wrapped
 
     return wrap
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 55d22f665a..783abc2b00 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -17,7 +17,39 @@
 from tests import unittest
 from twisted.internet import defer
 
-from synapse.storage._base import cached
+from synapse.storage._base import Cache, cached
+
+
+class CacheTestCase(unittest.TestCase):
+
+    def setUp(self):
+        self.cache = Cache("test")
+
+    def test_empty(self):
+        failed = False
+        try:
+            self.cache.get("foo")
+        except KeyError:
+            failed = True
+
+        self.assertTrue(failed)
+
+    def test_hit(self):
+        self.cache.prefill("foo", 123)
+
+        self.assertEquals(self.cache.get("foo"), 123)
+
+    def test_invalidate(self):
+        self.cache.prefill("foo", 123)
+        self.cache.invalidate("foo")
+
+        failed = False
+        try:
+            self.cache.get("foo")
+        except KeyError:
+            failed = True
+
+        self.assertTrue(failed)
 
 
 class CacheDecoratorTestCase(unittest.TestCase):