diff options
Diffstat (limited to 'tests/util/caches')
-rw-r--r-- | tests/util/caches/test_deferred_cache.py | 113 | ||||
-rw-r--r-- | tests/util/caches/test_descriptors.py | 281 |
2 files changed, 385 insertions, 9 deletions
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index 8a08ab6661..dadfabd46d 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest from functools import partial from twisted.internet import defer from synapse.util.caches.deferred_cache import DeferredCache +from tests.unittest import TestCase -class DeferredCacheTestCase(unittest.TestCase): + +class DeferredCacheTestCase(TestCase): def test_empty(self): cache = DeferredCache("test") failed = False @@ -36,7 +37,102 @@ class DeferredCacheTestCase(unittest.TestCase): cache = DeferredCache("test") cache.prefill("foo", 123) - self.assertEquals(cache.get("foo"), 123) + self.assertEquals(self.successResultOf(cache.get("foo")), 123) + + def test_hit_deferred(self): + cache = DeferredCache("test") + origin_d = defer.Deferred() + set_d = cache.set("k1", origin_d) + + # get should return an incomplete deferred + get_d = cache.get("k1") + self.assertFalse(get_d.called) + + # add a callback that will make sure that the set_d gets called before the get_d + def check1(r): + self.assertTrue(set_d.called) + return r + + # TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8. + # maybe we should fix that? + # get_d.addCallback(check1) + + # now fire off all the deferreds + origin_d.callback(99) + self.assertEqual(self.successResultOf(origin_d), 99) + self.assertEqual(self.successResultOf(set_d), 99) + self.assertEqual(self.successResultOf(get_d), 99) + + def test_callbacks(self): + """Invalidation callbacks are called at the right time""" + cache = DeferredCache("test") + callbacks = set() + + # start with an entry, with a callback + cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) + + # now replace that entry with a pending result + origin_d = defer.Deferred() + set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) + + # ... and also make a get request + get_d = cache.get("k1", callback=lambda: callbacks.add("get")) + + # we don't expect the invalidation callback for the original value to have + # been called yet, even though get() will now return a different result. + # I'm not sure if that is by design or not. + self.assertEqual(callbacks, set()) + + # now fire off all the deferreds + origin_d.callback(20) + self.assertEqual(self.successResultOf(set_d), 20) + self.assertEqual(self.successResultOf(get_d), 20) + + # now the original invalidation callback should have been called, but none of + # the others + self.assertEqual(callbacks, {"prefill"}) + callbacks.clear() + + # another update should invalidate both the previous results + cache.prefill("k1", 30) + self.assertEqual(callbacks, {"set", "get"}) + + def test_set_fail(self): + cache = DeferredCache("test") + callbacks = set() + + # start with an entry, with a callback + cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) + + # now replace that entry with a pending result + origin_d = defer.Deferred() + set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) + + # ... and also make a get request + get_d = cache.get("k1", callback=lambda: callbacks.add("get")) + + # none of the callbacks should have been called yet + self.assertEqual(callbacks, set()) + + # oh noes! fails! + e = Exception("oops") + origin_d.errback(e) + self.assertIs(self.failureResultOf(set_d, Exception).value, e) + self.assertIs(self.failureResultOf(get_d, Exception).value, e) + + # the callbacks for the failed requests should have been called. + # I'm not sure if this is deliberate or not. + self.assertEqual(callbacks, {"get", "set"}) + callbacks.clear() + + # the old value should still be returned now? + get_d2 = cache.get("k1", callback=lambda: callbacks.add("get2")) + self.assertEqual(self.successResultOf(get_d2), 10) + + # replacing the value now should run the callbacks for those requests + # which got the original result + cache.prefill("k1", 30) + self.assertEqual(callbacks, {"prefill", "get2"}) def test_get_immediate(self): cache = DeferredCache("test") @@ -82,16 +178,15 @@ class DeferredCacheTestCase(unittest.TestCase): d2 = defer.Deferred() cache.set("key2", d2, partial(record_callback, 1)) - # lookup should return observable deferreds - self.assertFalse(cache.get("key1").has_called()) - self.assertFalse(cache.get("key2").has_called()) + # lookup should return pending deferreds + self.assertFalse(cache.get("key1").called) + self.assertFalse(cache.get("key2").called) # let one of the lookups complete d2.callback("result2") - # for now at least, the cache will return real results rather than an - # observabledeferred - self.assertEqual(cache.get("key2"), "result2") + # now the cache will return a completed deferred + self.assertEqual(self.successResultOf(cache.get("key2")), "result2") # now do the invalidation cache.invalidate_all() diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 3d1f960869..2ad08f541b 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Set import mock @@ -130,6 +131,57 @@ class DescriptorTestCase(unittest.TestCase): d = obj.fn(1) self.failureResultOf(d, SynapseError) + def test_cache_with_async_exception(self): + """The wrapped function returns a failure + """ + + class Cls: + result = None + call_count = 0 + + @cached() + def fn(self, arg1): + self.call_count += 1 + return self.result + + obj = Cls() + callbacks = set() # type: Set[str] + + # set off an asynchronous request + obj.result = origin_d = defer.Deferred() + + d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1")) + self.assertFalse(d1.called) + + # a second request should also return a deferred, but should not call the + # function itself. + d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2")) + self.assertFalse(d2.called) + self.assertEqual(obj.call_count, 1) + + # no callbacks yet + self.assertEqual(callbacks, set()) + + # the original request fails + e = Exception("bzz") + origin_d.errback(e) + + # ... which should cause the lookups to fail similarly + self.assertIs(self.failureResultOf(d1, Exception).value, e) + self.assertIs(self.failureResultOf(d2, Exception).value, e) + + # ... and the callbacks to have been, uh, called. + self.assertEqual(callbacks, {"d1", "d2"}) + + # ... leaving the cache empty + self.assertEqual(len(obj.fn.cache.cache), 0) + + # and a second call should work as normal + obj.result = defer.succeed(100) + d3 = obj.fn(1) + self.assertEqual(self.successResultOf(d3), 100) + self.assertEqual(obj.call_count, 2) + def test_cache_logcontexts(self): """Check that logcontexts are set and restored correctly when using the cache.""" @@ -311,6 +363,235 @@ class DescriptorTestCase(unittest.TestCase): self.failureResultOf(d, SynapseError) +class CacheDecoratorTestCase(unittest.HomeserverTestCase): + """More tests for @cached + + The following is a set of tests that got lost in a different file for a while. + + There are probably duplicates of the tests in DescriptorTestCase. Ideally the + duplicates would be removed and the two sets of classes combined. + """ + + @defer.inlineCallbacks + def test_passthrough(self): + class A: + @cached() + def func(self, key): + return key + + a = A() + + self.assertEquals((yield a.func("foo")), "foo") + self.assertEquals((yield a.func("bar")), "bar") + + @defer.inlineCallbacks + def test_hit(self): + callcount = [0] + + class A: + @cached() + def func(self, key): + callcount[0] += 1 + return key + + a = A() + yield a.func("foo") + + self.assertEquals(callcount[0], 1) + + self.assertEquals((yield a.func("foo")), "foo") + self.assertEquals(callcount[0], 1) + + @defer.inlineCallbacks + def test_invalidate(self): + callcount = [0] + + class A: + @cached() + def func(self, key): + callcount[0] += 1 + return key + + a = A() + yield a.func("foo") + + self.assertEquals(callcount[0], 1) + + a.func.invalidate(("foo",)) + + yield a.func("foo") + + self.assertEquals(callcount[0], 2) + + def test_invalidate_missing(self): + class A: + @cached() + def func(self, key): + return key + + A().func.invalidate(("what",)) + + @defer.inlineCallbacks + def test_max_entries(self): + callcount = [0] + + class A: + @cached(max_entries=10) + def func(self, key): + callcount[0] += 1 + return key + + a = A() + + for k in range(0, 12): + yield a.func(k) + + self.assertEquals(callcount[0], 12) + + # There must have been at least 2 evictions, meaning if we calculate + # all 12 values again, we must get called at least 2 more times + for k in range(0, 12): + yield a.func(k) + + self.assertTrue( + callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0]) + ) + + def test_prefill(self): + callcount = [0] + + d = defer.succeed(123) + + class A: + @cached() + def func(self, key): + callcount[0] += 1 + return d + + a = A() + + a.func.prefill(("foo",), 456) + + self.assertEquals(a.func("foo").result, 456) + self.assertEquals(callcount[0], 0) + + @defer.inlineCallbacks + def test_invalidate_context(self): + callcount = [0] + callcount2 = [0] + + class A: + @cached() + def func(self, key): + callcount[0] += 1 + return key + + @cached(cache_context=True) + def func2(self, key, cache_context): + callcount2[0] += 1 + return self.func(key, on_invalidate=cache_context.invalidate) + + a = A() + yield a.func2("foo") + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 1) + + a.func.invalidate(("foo",)) + yield a.func("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 1) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + @defer.inlineCallbacks + def test_eviction_context(self): + callcount = [0] + callcount2 = [0] + + class A: + @cached(max_entries=2) + def func(self, key): + callcount[0] += 1 + return key + + @cached(cache_context=True) + def func2(self, key, cache_context): + callcount2[0] += 1 + return self.func(key, on_invalidate=cache_context.invalidate) + + a = A() + yield a.func2("foo") + yield a.func2("foo2") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + yield a.func2("foo") + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + yield a.func("foo3") + + self.assertEquals(callcount[0], 3) + self.assertEquals(callcount2[0], 2) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 4) + self.assertEquals(callcount2[0], 3) + + @defer.inlineCallbacks + def test_double_get(self): + callcount = [0] + callcount2 = [0] + + class A: + @cached() + def func(self, key): + callcount[0] += 1 + return key + + @cached(cache_context=True) + def func2(self, key, cache_context): + callcount2[0] += 1 + return self.func(key, on_invalidate=cache_context.invalidate) + + a = A() + a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 1) + + a.func2.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 1) + + yield a.func2("foo") + a.func2.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 2) + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 2) + + a.func.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 3) + yield a.func("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 3) + + class CachedListDescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cache(self): |