summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/storage/test__base.py48
-rw-r--r--tests/util/test_lrucache.py40
2 files changed, 88 insertions, 0 deletions
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 4fc3639de0..ab6095564a 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -17,6 +17,8 @@
 from tests import unittest
 from twisted.internet import defer
 
+from mock import Mock
+
 from synapse.util.async import ObservableDeferred
 
 from synapse.util.caches.descriptors import Cache, cached
@@ -265,3 +267,49 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         self.assertEquals(callcount[0], 4)
         self.assertEquals(callcount2[0], 3)
+
+    @defer.inlineCallbacks
+    def test_double_get(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 1)
+
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+        yield a.func2("foo")
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 2)
+
+        a.func.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 3)
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index bacec2f465..1eba5b535e 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -50,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
         self.assertEquals(cache.get("key"), 1)
         self.assertEquals(cache.setdefault("key", 2), 1)
         self.assertEquals(cache.get("key"), 1)
+        cache["key"] = 2  # Make sure overriding works.
+        self.assertEquals(cache.get("key"), 2)
 
     def test_pop(self):
         cache = LruCache(1)
@@ -84,6 +86,44 @@ class LruCacheTestCase(unittest.TestCase):
 
 
 class LruCacheCallbacksTestCase(unittest.TestCase):
+    def test_get(self):
+        m = Mock()
+        cache = LruCache(1)
+
+        cache.set("key", "value")
+        self.assertFalse(m.called)
+
+        cache.get("key", callback=m)
+        self.assertFalse(m.called)
+
+        cache.get("key", "value")
+        self.assertFalse(m.called)
+
+        cache.set("key", "value2")
+        self.assertEquals(m.call_count, 1)
+
+        cache.set("key", "value")
+        self.assertEquals(m.call_count, 1)
+
+    def test_multi_get(self):
+        m = Mock()
+        cache = LruCache(1)
+
+        cache.set("key", "value")
+        self.assertFalse(m.called)
+
+        cache.get("key", callback=m)
+        self.assertFalse(m.called)
+
+        cache.get("key", callback=m)
+        self.assertFalse(m.called)
+
+        cache.set("key", "value2")
+        self.assertEquals(m.call_count, 1)
+
+        cache.set("key", "value")
+        self.assertEquals(m.call_count, 1)
+
     def test_set(self):
         m = Mock()
         cache = LruCache(1)