diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 96b7dba5fe..ab6095564a 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -17,6 +17,8 @@
from tests import unittest
from twisted.internet import defer
+from mock import Mock
+
from synapse.util.async import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached
@@ -72,7 +74,7 @@ class CacheTestCase(unittest.TestCase):
cache.get(3)
def test_eviction_lru(self):
- cache = Cache("test", max_entries=2, lru=True)
+ cache = Cache("test", max_entries=2)
cache.prefill(1, "one")
cache.prefill(2, "two")
@@ -199,3 +201,115 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)
+
+ @defer.inlineCallbacks
+ def test_invalidate_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A(object):
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func.invalidate(("foo",))
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 1)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ @defer.inlineCallbacks
+ def test_eviction_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A(object):
+ @cached(max_entries=2)
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+ yield a.func2("foo2")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func("foo3")
+
+ self.assertEquals(callcount[0], 3)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 4)
+ self.assertEquals(callcount2[0], 3)
+
+ @defer.inlineCallbacks
+ def test_double_get(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A(object):
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+ yield a.func2("foo")
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 2)
+
+ a.func.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 3)
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index bab366fb7f..1eba5b535e 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -19,6 +19,8 @@ from .. import unittest
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
+from mock import Mock
+
class LruCacheTestCase(unittest.TestCase):
@@ -48,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get("key"), 1)
self.assertEquals(cache.setdefault("key", 2), 1)
self.assertEquals(cache.get("key"), 1)
+ cache["key"] = 2 # Make sure overriding works.
+ self.assertEquals(cache.get("key"), 2)
def test_pop(self):
cache = LruCache(1)
@@ -79,3 +83,152 @@ class LruCacheTestCase(unittest.TestCase):
cache["key"] = 1
cache.clear()
self.assertEquals(len(cache), 0)
+
+
+class LruCacheCallbacksTestCase(unittest.TestCase):
+ def test_get(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value")
+ self.assertFalse(m.called)
+
+ cache.get("key", callback=m)
+ self.assertFalse(m.called)
+
+ cache.get("key", "value")
+ self.assertFalse(m.called)
+
+ cache.set("key", "value2")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ def test_multi_get(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value")
+ self.assertFalse(m.called)
+
+ cache.get("key", callback=m)
+ self.assertFalse(m.called)
+
+ cache.get("key", callback=m)
+ self.assertFalse(m.called)
+
+ cache.set("key", "value2")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ def test_set(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value", m)
+ self.assertFalse(m.called)
+
+ cache.set("key", "value")
+ self.assertFalse(m.called)
+
+ cache.set("key", "value2")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ def test_pop(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value", m)
+ self.assertFalse(m.called)
+
+ cache.pop("key")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ cache.pop("key")
+ self.assertEquals(m.call_count, 1)
+
+ def test_del_multi(self):
+ m1 = Mock()
+ m2 = Mock()
+ m3 = Mock()
+ m4 = Mock()
+ cache = LruCache(4, 2, cache_type=TreeCache)
+
+ cache.set(("a", "1"), "value", m1)
+ cache.set(("a", "2"), "value", m2)
+ cache.set(("b", "1"), "value", m3)
+ cache.set(("b", "2"), "value", m4)
+
+ self.assertEquals(m1.call_count, 0)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+ self.assertEquals(m4.call_count, 0)
+
+ cache.del_multi(("a",))
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 1)
+ self.assertEquals(m3.call_count, 0)
+ self.assertEquals(m4.call_count, 0)
+
+ def test_clear(self):
+ m1 = Mock()
+ m2 = Mock()
+ cache = LruCache(5)
+
+ cache.set("key1", "value", m1)
+ cache.set("key2", "value", m2)
+
+ self.assertEquals(m1.call_count, 0)
+ self.assertEquals(m2.call_count, 0)
+
+ cache.clear()
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 1)
+
+ def test_eviction(self):
+ m1 = Mock(name="m1")
+ m2 = Mock(name="m2")
+ m3 = Mock(name="m3")
+ cache = LruCache(2)
+
+ cache.set("key1", "value", m1)
+ cache.set("key2", "value", m2)
+
+ self.assertEquals(m1.call_count, 0)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.set("key3", "value", m3)
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.set("key3", "value")
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.get("key2")
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.set("key1", "value", m1)
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 1)
|