diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2552a74f85..27ea65a0f6 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -53,6 +53,47 @@ cache_counter = metrics.register_cache(
)
+class Cache(object):
+
+ def __init__(self, name, max_entries=1000, keylen=1):
+ self.cache = OrderedDict()
+
+ self.max_entries = max_entries
+ self.name = name
+ self.keylen = keylen
+
+ caches_by_name[name] = self.cache
+
+ def get(self, *keyargs):
+ if len(keyargs) != self.keylen:
+ raise ValueError("Expected a key to have %d items", self.keylen)
+
+ if keyargs in self.cache:
+ cache_counter.inc_hits(self.name)
+ return self.cache[keyargs]
+
+ cache_counter.inc_misses(self.name)
+ raise KeyError()
+
+ def prefill(self, *args): # because I can't *keyargs, value
+ keyargs = args[:-1]
+ value = args[-1]
+
+ if len(keyargs) != self.keylen:
+ raise ValueError("Expected a key to have %d items", self.keylen)
+
+ while len(self.cache) > self.max_entries:
+ self.cache.popitem(last=False)
+
+ self.cache[keyargs] = value
+
+ def invalidate(self, *keyargs):
+ if len(keyargs) != self.keylen:
+ raise ValueError("Expected a key to have %d items", self.keylen)
+
+ self.cache.pop(keyargs, None)
+
+
# TODO(paul):
# * consider other eviction strategies - LRU?
def cached(max_entries=1000, num_args=1):
@@ -70,48 +111,26 @@ def cached(max_entries=1000, num_args=1):
calling the calculation function.
"""
def wrap(orig):
- cache = OrderedDict()
- name = orig.__name__
-
- caches_by_name[name] = cache
-
- def prefill(*args): # because I can't *keyargs, value
- keyargs = args[:-1]
- value = args[-1]
-
- if len(keyargs) != num_args:
- raise ValueError("Expected a call to have %d arguments", num_args)
-
- while len(cache) > max_entries:
- cache.popitem(last=False)
-
- cache[keyargs] = value
+ cache = Cache(
+ name=orig.__name__,
+ max_entries=max_entries,
+ keylen=num_args,
+ )
@functools.wraps(orig)
@defer.inlineCallbacks
def wrapped(self, *keyargs):
- if len(keyargs) != num_args:
- raise ValueError("Expected a call to have %d arguments", num_args)
-
- if keyargs in cache:
- cache_counter.inc_hits(name)
- defer.returnValue(cache[keyargs])
-
- cache_counter.inc_misses(name)
- ret = yield orig(self, *keyargs)
-
- prefill(*keyargs + (ret,))
-
- defer.returnValue(ret)
+ try:
+ defer.returnValue(cache.get(*keyargs))
+ except KeyError:
+ ret = yield orig(self, *keyargs)
- def invalidate(*keyargs):
- if len(keyargs) != num_args:
- raise ValueError("Expected a call to have %d arguments", num_args)
+ cache.prefill(*keyargs + (ret,))
- cache.pop(keyargs, None)
+ defer.returnValue(ret)
- wrapped.invalidate = invalidate
- wrapped.prefill = prefill
+ wrapped.invalidate = cache.invalidate
+ wrapped.prefill = cache.prefill
return wrapped
return wrap
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 55d22f665a..783abc2b00 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -17,7 +17,39 @@
from tests import unittest
from twisted.internet import defer
-from synapse.storage._base import cached
+from synapse.storage._base import Cache, cached
+
+
+class CacheTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self.cache = Cache("test")
+
+ def test_empty(self):
+ failed = False
+ try:
+ self.cache.get("foo")
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ def test_hit(self):
+ self.cache.prefill("foo", 123)
+
+ self.assertEquals(self.cache.get("foo"), 123)
+
+ def test_invalidate(self):
+ self.cache.prefill("foo", 123)
+ self.cache.invalidate("foo")
+
+ failed = False
+ try:
+ self.cache.get("foo")
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
class CacheDecoratorTestCase(unittest.TestCase):
|