summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/storage/test__base.py228
-rw-r--r--tests/util/caches/test_descriptors.py230
2 files changed, 230 insertions, 228 deletions
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 8e69b1e9cc..1ac4ebc61d 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -15,237 +15,9 @@
 # limitations under the License.
 
 
-from mock import Mock
-
-from twisted.internet import defer
-
-from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import cached
-
 from tests import unittest
 
 
-class CacheDecoratorTestCase(unittest.HomeserverTestCase):
-    @defer.inlineCallbacks
-    def test_passthrough(self):
-        class A:
-            @cached()
-            def func(self, key):
-                return key
-
-        a = A()
-
-        self.assertEquals((yield a.func("foo")), "foo")
-        self.assertEquals((yield a.func("bar")), "bar")
-
-    @defer.inlineCallbacks
-    def test_hit(self):
-        callcount = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-        a = A()
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 1)
-
-        self.assertEquals((yield a.func("foo")), "foo")
-        self.assertEquals(callcount[0], 1)
-
-    @defer.inlineCallbacks
-    def test_invalidate(self):
-        callcount = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-        a = A()
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 1)
-
-        a.func.invalidate(("foo",))
-
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 2)
-
-    def test_invalidate_missing(self):
-        class A:
-            @cached()
-            def func(self, key):
-                return key
-
-        A().func.invalidate(("what",))
-
-    @defer.inlineCallbacks
-    def test_max_entries(self):
-        callcount = [0]
-
-        class A:
-            @cached(max_entries=10)
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-        a = A()
-
-        for k in range(0, 12):
-            yield a.func(k)
-
-        self.assertEquals(callcount[0], 12)
-
-        # There must have been at least 2 evictions, meaning if we calculate
-        # all 12 values again, we must get called at least 2 more times
-        for k in range(0, 12):
-            yield a.func(k)
-
-        self.assertTrue(
-            callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
-        )
-
-    def test_prefill(self):
-        callcount = [0]
-
-        d = defer.succeed(123)
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return d
-
-        a = A()
-
-        a.func.prefill(("foo",), ObservableDeferred(d))
-
-        self.assertEquals(a.func("foo").result, d.result)
-        self.assertEquals(callcount[0], 0)
-
-    @defer.inlineCallbacks
-    def test_invalidate_context(self):
-        callcount = [0]
-        callcount2 = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-            @cached(cache_context=True)
-            def func2(self, key, cache_context):
-                callcount2[0] += 1
-                return self.func(key, on_invalidate=cache_context.invalidate)
-
-        a = A()
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 1)
-
-        a.func.invalidate(("foo",))
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 1)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-    @defer.inlineCallbacks
-    def test_eviction_context(self):
-        callcount = [0]
-        callcount2 = [0]
-
-        class A:
-            @cached(max_entries=2)
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-            @cached(cache_context=True)
-            def func2(self, key, cache_context):
-                callcount2[0] += 1
-                return self.func(key, on_invalidate=cache_context.invalidate)
-
-        a = A()
-        yield a.func2("foo")
-        yield a.func2("foo2")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func2("foo")
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func("foo3")
-
-        self.assertEquals(callcount[0], 3)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 4)
-        self.assertEquals(callcount2[0], 3)
-
-    @defer.inlineCallbacks
-    def test_double_get(self):
-        callcount = [0]
-        callcount2 = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-            @cached(cache_context=True)
-            def func2(self, key, cache_context):
-                callcount2[0] += 1
-                return self.func(key, on_invalidate=cache_context.invalidate)
-
-        a = A()
-        a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 1)
-
-        a.func2.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
-
-        yield a.func2("foo")
-        a.func2.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
-
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 2)
-
-        a.func.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 3)
-
-
 class UpsertManyTests(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         self.storage = hs.get_datastore()
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 3d1f960869..3d738afa7f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -27,6 +27,7 @@ from synapse.logging.context import (
     current_context,
     make_deferred_yieldable,
 )
+from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches import descriptors
 from synapse.util.caches.descriptors import cached
 
@@ -311,6 +312,235 @@ class DescriptorTestCase(unittest.TestCase):
         self.failureResultOf(d, SynapseError)
 
 
+class CacheDecoratorTestCase(unittest.HomeserverTestCase):
+    """More tests for @cached
+
+    The following is a set of tests that got lost in a different file for a while.
+
+    There are probably duplicates of the tests in DescriptorTestCase. Ideally the
+    duplicates would be removed and the two sets of classes combined.
+    """
+
+    @defer.inlineCallbacks
+    def test_passthrough(self):
+        class A:
+            @cached()
+            def func(self, key):
+                return key
+
+        a = A()
+
+        self.assertEquals((yield a.func("foo")), "foo")
+        self.assertEquals((yield a.func("bar")), "bar")
+
+    @defer.inlineCallbacks
+    def test_hit(self):
+        callcount = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 1)
+
+        self.assertEquals((yield a.func("foo")), "foo")
+        self.assertEquals(callcount[0], 1)
+
+    @defer.inlineCallbacks
+    def test_invalidate(self):
+        callcount = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 1)
+
+        a.func.invalidate(("foo",))
+
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+
+    def test_invalidate_missing(self):
+        class A:
+            @cached()
+            def func(self, key):
+                return key
+
+        A().func.invalidate(("what",))
+
+    @defer.inlineCallbacks
+    def test_max_entries(self):
+        callcount = [0]
+
+        class A:
+            @cached(max_entries=10)
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
+
+        for k in range(0, 12):
+            yield a.func(k)
+
+        self.assertEquals(callcount[0], 12)
+
+        # There must have been at least 2 evictions, meaning if we calculate
+        # all 12 values again, we must get called at least 2 more times
+        for k in range(0, 12):
+            yield a.func(k)
+
+        self.assertTrue(
+            callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
+        )
+
+    def test_prefill(self):
+        callcount = [0]
+
+        d = defer.succeed(123)
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return d
+
+        a = A()
+
+        a.func.prefill(("foo",), ObservableDeferred(d))
+
+        self.assertEquals(a.func("foo").result, d.result)
+        self.assertEquals(callcount[0], 0)
+
+    @defer.inlineCallbacks
+    def test_invalidate_context(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 1)
+
+        a.func.invalidate(("foo",))
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 1)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+    @defer.inlineCallbacks
+    def test_eviction_context(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A:
+            @cached(max_entries=2)
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        yield a.func2("foo")
+        yield a.func2("foo2")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func("foo3")
+
+        self.assertEquals(callcount[0], 3)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 4)
+        self.assertEquals(callcount2[0], 3)
+
+    @defer.inlineCallbacks
+    def test_double_get(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 1)
+
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+        yield a.func2("foo")
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 2)
+
+        a.func.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 3)
+
+
 class CachedListDescriptorTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_cache(self):