diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 8e69b1e9cc..1ac4ebc61d 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -15,237 +15,9 @@
# limitations under the License.
-from mock import Mock
-
-from twisted.internet import defer
-
-from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import cached
-
from tests import unittest
-class CacheDecoratorTestCase(unittest.HomeserverTestCase):
- @defer.inlineCallbacks
- def test_passthrough(self):
- class A:
- @cached()
- def func(self, key):
- return key
-
- a = A()
-
- self.assertEquals((yield a.func("foo")), "foo")
- self.assertEquals((yield a.func("bar")), "bar")
-
- @defer.inlineCallbacks
- def test_hit(self):
- callcount = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- a = A()
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 1)
-
- self.assertEquals((yield a.func("foo")), "foo")
- self.assertEquals(callcount[0], 1)
-
- @defer.inlineCallbacks
- def test_invalidate(self):
- callcount = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- a = A()
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 1)
-
- a.func.invalidate(("foo",))
-
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 2)
-
- def test_invalidate_missing(self):
- class A:
- @cached()
- def func(self, key):
- return key
-
- A().func.invalidate(("what",))
-
- @defer.inlineCallbacks
- def test_max_entries(self):
- callcount = [0]
-
- class A:
- @cached(max_entries=10)
- def func(self, key):
- callcount[0] += 1
- return key
-
- a = A()
-
- for k in range(0, 12):
- yield a.func(k)
-
- self.assertEquals(callcount[0], 12)
-
- # There must have been at least 2 evictions, meaning if we calculate
- # all 12 values again, we must get called at least 2 more times
- for k in range(0, 12):
- yield a.func(k)
-
- self.assertTrue(
- callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
- )
-
- def test_prefill(self):
- callcount = [0]
-
- d = defer.succeed(123)
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return d
-
- a = A()
-
- a.func.prefill(("foo",), ObservableDeferred(d))
-
- self.assertEquals(a.func("foo").result, d.result)
- self.assertEquals(callcount[0], 0)
-
- @defer.inlineCallbacks
- def test_invalidate_context(self):
- callcount = [0]
- callcount2 = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- @cached(cache_context=True)
- def func2(self, key, cache_context):
- callcount2[0] += 1
- return self.func(key, on_invalidate=cache_context.invalidate)
-
- a = A()
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 1)
- self.assertEquals(callcount2[0], 1)
-
- a.func.invalidate(("foo",))
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 1)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- @defer.inlineCallbacks
- def test_eviction_context(self):
- callcount = [0]
- callcount2 = [0]
-
- class A:
- @cached(max_entries=2)
- def func(self, key):
- callcount[0] += 1
- return key
-
- @cached(cache_context=True)
- def func2(self, key, cache_context):
- callcount2[0] += 1
- return self.func(key, on_invalidate=cache_context.invalidate)
-
- a = A()
- yield a.func2("foo")
- yield a.func2("foo2")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func2("foo")
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func("foo3")
-
- self.assertEquals(callcount[0], 3)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 4)
- self.assertEquals(callcount2[0], 3)
-
- @defer.inlineCallbacks
- def test_double_get(self):
- callcount = [0]
- callcount2 = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- @cached(cache_context=True)
- def func2(self, key, cache_context):
- callcount2[0] += 1
- return self.func(key, on_invalidate=cache_context.invalidate)
-
- a = A()
- a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 1)
- self.assertEquals(callcount2[0], 1)
-
- a.func2.invalidate(("foo",))
- self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
-
- yield a.func2("foo")
- a.func2.invalidate(("foo",))
- self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
-
- self.assertEquals(callcount[0], 1)
- self.assertEquals(callcount2[0], 2)
-
- a.func.invalidate(("foo",))
- self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 3)
-
-
class UpsertManyTests(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.storage = hs.get_datastore()
|