summary refs log tree commit diff
path: root/tests/util
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util')
-rw-r--r--tests/util/caches/test_deferred_cache.py251
-rw-r--r--tests/util/caches/test_descriptors.py381
-rw-r--r--tests/util/test_lrucache.py12
3 files changed, 598 insertions, 46 deletions
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
new file mode 100644
index 0000000000..dadfabd46d
--- /dev/null
+++ b/tests/util/caches/test_deferred_cache.py
@@ -0,0 +1,251 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+
+from twisted.internet import defer
+
+from synapse.util.caches.deferred_cache import DeferredCache
+
+from tests.unittest import TestCase
+
+
+class DeferredCacheTestCase(TestCase):
+    def test_empty(self):
+        cache = DeferredCache("test")
+        failed = False
+        try:
+            cache.get("foo")
+        except KeyError:
+            failed = True
+
+        self.assertTrue(failed)
+
+    def test_hit(self):
+        cache = DeferredCache("test")
+        cache.prefill("foo", 123)
+
+        self.assertEquals(self.successResultOf(cache.get("foo")), 123)
+
+    def test_hit_deferred(self):
+        cache = DeferredCache("test")
+        origin_d = defer.Deferred()
+        set_d = cache.set("k1", origin_d)
+
+        # get should return an incomplete deferred
+        get_d = cache.get("k1")
+        self.assertFalse(get_d.called)
+
+        # add a callback that will make sure that the set_d gets called before the get_d
+        def check1(r):
+            self.assertTrue(set_d.called)
+            return r
+
+        # TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
+        #   maybe we should fix that?
+        # get_d.addCallback(check1)
+
+        # now fire off all the deferreds
+        origin_d.callback(99)
+        self.assertEqual(self.successResultOf(origin_d), 99)
+        self.assertEqual(self.successResultOf(set_d), 99)
+        self.assertEqual(self.successResultOf(get_d), 99)
+
+    def test_callbacks(self):
+        """Invalidation callbacks are called at the right time"""
+        cache = DeferredCache("test")
+        callbacks = set()
+
+        # start with an entry, with a callback
+        cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
+
+        # now replace that entry with a pending result
+        origin_d = defer.Deferred()
+        set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
+
+        # ... and also make a get request
+        get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
+
+        # we don't expect the invalidation callback for the original value to have
+        # been called yet, even though get() will now return a different result.
+        # I'm not sure if that is by design or not.
+        self.assertEqual(callbacks, set())
+
+        # now fire off all the deferreds
+        origin_d.callback(20)
+        self.assertEqual(self.successResultOf(set_d), 20)
+        self.assertEqual(self.successResultOf(get_d), 20)
+
+        # now the original invalidation callback should have been called, but none of
+        # the others
+        self.assertEqual(callbacks, {"prefill"})
+        callbacks.clear()
+
+        # another update should invalidate both the previous results
+        cache.prefill("k1", 30)
+        self.assertEqual(callbacks, {"set", "get"})
+
+    def test_set_fail(self):
+        cache = DeferredCache("test")
+        callbacks = set()
+
+        # start with an entry, with a callback
+        cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
+
+        # now replace that entry with a pending result
+        origin_d = defer.Deferred()
+        set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
+
+        # ... and also make a get request
+        get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
+
+        # none of the callbacks should have been called yet
+        self.assertEqual(callbacks, set())
+
+        # oh noes! fails!
+        e = Exception("oops")
+        origin_d.errback(e)
+        self.assertIs(self.failureResultOf(set_d, Exception).value, e)
+        self.assertIs(self.failureResultOf(get_d, Exception).value, e)
+
+        # the callbacks for the failed requests should have been called.
+        # I'm not sure if this is deliberate or not.
+        self.assertEqual(callbacks, {"get", "set"})
+        callbacks.clear()
+
+        # the old value should still be returned now?
+        get_d2 = cache.get("k1", callback=lambda: callbacks.add("get2"))
+        self.assertEqual(self.successResultOf(get_d2), 10)
+
+        # replacing the value now should run the callbacks for those requests
+        # which got the original result
+        cache.prefill("k1", 30)
+        self.assertEqual(callbacks, {"prefill", "get2"})
+
+    def test_get_immediate(self):
+        cache = DeferredCache("test")
+        d1 = defer.Deferred()
+        cache.set("key1", d1)
+
+        # get_immediate should return default
+        v = cache.get_immediate("key1", 1)
+        self.assertEqual(v, 1)
+
+        # now complete the set
+        d1.callback(2)
+
+        # get_immediate should return result
+        v = cache.get_immediate("key1", 1)
+        self.assertEqual(v, 2)
+
+    def test_invalidate(self):
+        cache = DeferredCache("test")
+        cache.prefill(("foo",), 123)
+        cache.invalidate(("foo",))
+
+        failed = False
+        try:
+            cache.get(("foo",))
+        except KeyError:
+            failed = True
+
+        self.assertTrue(failed)
+
+    def test_invalidate_all(self):
+        cache = DeferredCache("testcache")
+
+        callback_record = [False, False]
+
+        def record_callback(idx):
+            callback_record[idx] = True
+
+        # add a couple of pending entries
+        d1 = defer.Deferred()
+        cache.set("key1", d1, partial(record_callback, 0))
+
+        d2 = defer.Deferred()
+        cache.set("key2", d2, partial(record_callback, 1))
+
+        # lookup should return pending deferreds
+        self.assertFalse(cache.get("key1").called)
+        self.assertFalse(cache.get("key2").called)
+
+        # let one of the lookups complete
+        d2.callback("result2")
+
+        # now the cache will return a completed deferred
+        self.assertEqual(self.successResultOf(cache.get("key2")), "result2")
+
+        # now do the invalidation
+        cache.invalidate_all()
+
+        # lookup should fail
+        with self.assertRaises(KeyError):
+            cache.get("key1")
+        with self.assertRaises(KeyError):
+            cache.get("key2")
+
+        # both callbacks should have been callbacked
+        self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
+        self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
+
+        # letting the other lookup complete should do nothing
+        d1.callback("result1")
+        with self.assertRaises(KeyError):
+            cache.get("key1", None)
+
+    def test_eviction(self):
+        cache = DeferredCache(
+            "test", max_entries=2, apply_cache_factor_from_config=False
+        )
+
+        cache.prefill(1, "one")
+        cache.prefill(2, "two")
+        cache.prefill(3, "three")  # 1 will be evicted
+
+        failed = False
+        try:
+            cache.get(1)
+        except KeyError:
+            failed = True
+
+        self.assertTrue(failed)
+
+        cache.get(2)
+        cache.get(3)
+
+    def test_eviction_lru(self):
+        cache = DeferredCache(
+            "test", max_entries=2, apply_cache_factor_from_config=False
+        )
+
+        cache.prefill(1, "one")
+        cache.prefill(2, "two")
+
+        # Now access 1 again, thus causing 2 to be least-recently used
+        cache.get(1)
+
+        cache.prefill(3, "three")
+
+        failed = False
+        try:
+            cache.get(2)
+        except KeyError:
+            failed = True
+
+        self.assertTrue(failed)
+
+        cache.get(1)
+        cache.get(3)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 677e925477..cf1e3203a4 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from functools import partial
+from typing import Set
 
 import mock
 
@@ -29,60 +29,50 @@ from synapse.logging.context import (
     make_deferred_yieldable,
 )
 from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, lru_cache
 
 from tests import unittest
+from tests.test_utils import get_awaitable_result
 
 logger = logging.getLogger(__name__)
 
 
-def run_on_reactor():
-    d = defer.Deferred()
-    reactor.callLater(0, d.callback, 0)
-    return make_deferred_yieldable(d)
-
-
-class CacheTestCase(unittest.TestCase):
-    def test_invalidate_all(self):
-        cache = descriptors.Cache("testcache")
-
-        callback_record = [False, False]
-
-        def record_callback(idx):
-            callback_record[idx] = True
-
-        # add a couple of pending entries
-        d1 = defer.Deferred()
-        cache.set("key1", d1, partial(record_callback, 0))
-
-        d2 = defer.Deferred()
-        cache.set("key2", d2, partial(record_callback, 1))
-
-        # lookup should return observable deferreds
-        self.assertFalse(cache.get("key1").has_called())
-        self.assertFalse(cache.get("key2").has_called())
+class LruCacheDecoratorTestCase(unittest.TestCase):
+    def test_base(self):
+        class Cls:
+            def __init__(self):
+                self.mock = mock.Mock()
 
-        # let one of the lookups complete
-        d2.callback("result2")
+            @lru_cache()
+            def fn(self, arg1, arg2):
+                return self.mock(arg1, arg2)
 
-        # for now at least, the cache will return real results rather than an
-        # observabledeferred
-        self.assertEqual(cache.get("key2"), "result2")
+        obj = Cls()
+        obj.mock.return_value = "fish"
+        r = obj.fn(1, 2)
+        self.assertEqual(r, "fish")
+        obj.mock.assert_called_once_with(1, 2)
+        obj.mock.reset_mock()
 
-        # now do the invalidation
-        cache.invalidate_all()
+        # a call with different params should call the mock again
+        obj.mock.return_value = "chips"
+        r = obj.fn(1, 3)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_called_once_with(1, 3)
+        obj.mock.reset_mock()
 
-        # lookup should return none
-        self.assertIsNone(cache.get("key1", None))
-        self.assertIsNone(cache.get("key2", None))
+        # the two values should now be cached
+        r = obj.fn(1, 2)
+        self.assertEqual(r, "fish")
+        r = obj.fn(1, 3)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_not_called()
 
-        # both callbacks should have been callbacked
-        self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
-        self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
 
-        # letting the other lookup complete should do nothing
-        d1.callback("result1")
-        self.assertIsNone(cache.get("key1", None))
+def run_on_reactor():
+    d = defer.Deferred()
+    reactor.callLater(0, d.callback, 0)
+    return make_deferred_yieldable(d)
 
 
 class DescriptorTestCase(unittest.TestCase):
@@ -174,6 +164,57 @@ class DescriptorTestCase(unittest.TestCase):
         d = obj.fn(1)
         self.failureResultOf(d, SynapseError)
 
+    def test_cache_with_async_exception(self):
+        """The wrapped function returns a failure
+        """
+
+        class Cls:
+            result = None
+            call_count = 0
+
+            @cached()
+            def fn(self, arg1):
+                self.call_count += 1
+                return self.result
+
+        obj = Cls()
+        callbacks = set()  # type: Set[str]
+
+        # set off an asynchronous request
+        obj.result = origin_d = defer.Deferred()
+
+        d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
+        self.assertFalse(d1.called)
+
+        # a second request should also return a deferred, but should not call the
+        # function itself.
+        d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2"))
+        self.assertFalse(d2.called)
+        self.assertEqual(obj.call_count, 1)
+
+        # no callbacks yet
+        self.assertEqual(callbacks, set())
+
+        # the original request fails
+        e = Exception("bzz")
+        origin_d.errback(e)
+
+        # ... which should cause the lookups to fail similarly
+        self.assertIs(self.failureResultOf(d1, Exception).value, e)
+        self.assertIs(self.failureResultOf(d2, Exception).value, e)
+
+        # ... and the callbacks to have been, uh, called.
+        self.assertEqual(callbacks, {"d1", "d2"})
+
+        # ... leaving the cache empty
+        self.assertEqual(len(obj.fn.cache.cache), 0)
+
+        # and a second call should work as normal
+        obj.result = defer.succeed(100)
+        d3 = obj.fn(1)
+        self.assertEqual(self.successResultOf(d3), 100)
+        self.assertEqual(obj.call_count, 2)
+
     def test_cache_logcontexts(self):
         """Check that logcontexts are set and restored correctly when
         using the cache."""
@@ -354,6 +395,260 @@ class DescriptorTestCase(unittest.TestCase):
         d = obj.fn(1)
         self.failureResultOf(d, SynapseError)
 
+    def test_invalidate_cascade(self):
+        """Invalidations should cascade up through cache contexts"""
+
+        class Cls:
+            @cached(cache_context=True)
+            async def func1(self, key, cache_context):
+                return await self.func2(key, on_invalidate=cache_context.invalidate)
+
+            @cached(cache_context=True)
+            async def func2(self, key, cache_context):
+                return self.func3(key, on_invalidate=cache_context.invalidate)
+
+            @lru_cache(cache_context=True)
+            def func3(self, key, cache_context):
+                self.invalidate = cache_context.invalidate
+                return 42
+
+        obj = Cls()
+
+        top_invalidate = mock.Mock()
+        r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate))
+        self.assertEqual(r, 42)
+        obj.invalidate()
+        top_invalidate.assert_called_once()
+
+
+class CacheDecoratorTestCase(unittest.HomeserverTestCase):
+    """More tests for @cached
+
+    The following is a set of tests that got lost in a different file for a while.
+
+    There are probably duplicates of the tests in DescriptorTestCase. Ideally the
+    duplicates would be removed and the two sets of classes combined.
+    """
+
+    @defer.inlineCallbacks
+    def test_passthrough(self):
+        class A:
+            @cached()
+            def func(self, key):
+                return key
+
+        a = A()
+
+        self.assertEquals((yield a.func("foo")), "foo")
+        self.assertEquals((yield a.func("bar")), "bar")
+
+    @defer.inlineCallbacks
+    def test_hit(self):
+        callcount = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 1)
+
+        self.assertEquals((yield a.func("foo")), "foo")
+        self.assertEquals(callcount[0], 1)
+
+    @defer.inlineCallbacks
+    def test_invalidate(self):
+        callcount = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 1)
+
+        a.func.invalidate(("foo",))
+
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+
+    def test_invalidate_missing(self):
+        class A:
+            @cached()
+            def func(self, key):
+                return key
+
+        A().func.invalidate(("what",))
+
+    @defer.inlineCallbacks
+    def test_max_entries(self):
+        callcount = [0]
+
+        class A:
+            @cached(max_entries=10)
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
+
+        for k in range(0, 12):
+            yield a.func(k)
+
+        self.assertEquals(callcount[0], 12)
+
+        # There must have been at least 2 evictions, meaning if we calculate
+        # all 12 values again, we must get called at least 2 more times
+        for k in range(0, 12):
+            yield a.func(k)
+
+        self.assertTrue(
+            callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
+        )
+
+    def test_prefill(self):
+        callcount = [0]
+
+        d = defer.succeed(123)
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return d
+
+        a = A()
+
+        a.func.prefill(("foo",), 456)
+
+        self.assertEquals(a.func("foo").result, 456)
+        self.assertEquals(callcount[0], 0)
+
+    @defer.inlineCallbacks
+    def test_invalidate_context(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 1)
+
+        a.func.invalidate(("foo",))
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 1)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+    @defer.inlineCallbacks
+    def test_eviction_context(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A:
+            @cached(max_entries=2)
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        yield a.func2("foo")
+        yield a.func2("foo2")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func("foo3")
+
+        self.assertEquals(callcount[0], 3)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 4)
+        self.assertEquals(callcount2[0], 3)
+
+    @defer.inlineCallbacks
+    def test_double_get(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 1)
+
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+        yield a.func2("foo")
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 2)
+
+        a.func.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 3)
+
 
 class CachedListDescriptorTestCase(unittest.TestCase):
     @defer.inlineCallbacks
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 0adb2174af..a739a6aaaf 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -19,7 +19,8 @@ from mock import Mock
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache
 
-from .. import unittest
+from tests import unittest
+from tests.unittest import override_config
 
 
 class LruCacheTestCase(unittest.HomeserverTestCase):
@@ -59,7 +60,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
         self.assertEquals(cache.pop("key"), None)
 
     def test_del_multi(self):
-        cache = LruCache(4, 2, cache_type=TreeCache)
+        cache = LruCache(4, keylen=2, cache_type=TreeCache)
         cache[("animal", "cat")] = "mew"
         cache[("animal", "dog")] = "woof"
         cache[("vehicles", "car")] = "vroom"
@@ -83,6 +84,11 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
         cache.clear()
         self.assertEquals(len(cache), 0)
 
+    @override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
+    def test_special_size(self):
+        cache = LruCache(10, "mycache")
+        self.assertEqual(cache.max_size, 100)
+
 
 class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
     def test_get(self):
@@ -160,7 +166,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         m2 = Mock()
         m3 = Mock()
         m4 = Mock()
-        cache = LruCache(4, 2, cache_type=TreeCache)
+        cache = LruCache(4, keylen=2, cache_type=TreeCache)
 
         cache.set(("a", "1"), "value", callbacks=[m1])
         cache.set(("a", "2"), "value", callbacks=[m2])