diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 3d1f960869..2ad08f541b 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Set
import mock
@@ -130,6 +131,57 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
+ def test_cache_with_async_exception(self):
+ """The wrapped function returns a failure
+ """
+
+ class Cls:
+ result = None
+ call_count = 0
+
+ @cached()
+ def fn(self, arg1):
+ self.call_count += 1
+ return self.result
+
+ obj = Cls()
+ callbacks = set() # type: Set[str]
+
+ # set off an asynchronous request
+ obj.result = origin_d = defer.Deferred()
+
+ d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
+ self.assertFalse(d1.called)
+
+ # a second request should also return a deferred, but should not call the
+ # function itself.
+ d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2"))
+ self.assertFalse(d2.called)
+ self.assertEqual(obj.call_count, 1)
+
+ # no callbacks yet
+ self.assertEqual(callbacks, set())
+
+ # the original request fails
+ e = Exception("bzz")
+ origin_d.errback(e)
+
+ # ... which should cause the lookups to fail similarly
+ self.assertIs(self.failureResultOf(d1, Exception).value, e)
+ self.assertIs(self.failureResultOf(d2, Exception).value, e)
+
+ # ... and the callbacks to have been, uh, called.
+ self.assertEqual(callbacks, {"d1", "d2"})
+
+ # ... leaving the cache empty
+ self.assertEqual(len(obj.fn.cache.cache), 0)
+
+ # and a second call should work as normal
+ obj.result = defer.succeed(100)
+ d3 = obj.fn(1)
+ self.assertEqual(self.successResultOf(d3), 100)
+ self.assertEqual(obj.call_count, 2)
+
def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""
@@ -311,6 +363,235 @@ class DescriptorTestCase(unittest.TestCase):
self.failureResultOf(d, SynapseError)
+class CacheDecoratorTestCase(unittest.HomeserverTestCase):
+ """More tests for @cached
+
+ The following is a set of tests that got lost in a different file for a while.
+
+ There are probably duplicates of the tests in DescriptorTestCase. Ideally the
+ duplicates would be removed and the two sets of classes combined.
+ """
+
+ @defer.inlineCallbacks
+ def test_passthrough(self):
+ class A:
+ @cached()
+ def func(self, key):
+ return key
+
+ a = A()
+
+ self.assertEquals((yield a.func("foo")), "foo")
+ self.assertEquals((yield a.func("bar")), "bar")
+
+ @defer.inlineCallbacks
+ def test_hit(self):
+ callcount = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ a = A()
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 1)
+
+ self.assertEquals((yield a.func("foo")), "foo")
+ self.assertEquals(callcount[0], 1)
+
+ @defer.inlineCallbacks
+ def test_invalidate(self):
+ callcount = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ a = A()
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 1)
+
+ a.func.invalidate(("foo",))
+
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+
+ def test_invalidate_missing(self):
+ class A:
+ @cached()
+ def func(self, key):
+ return key
+
+ A().func.invalidate(("what",))
+
+ @defer.inlineCallbacks
+ def test_max_entries(self):
+ callcount = [0]
+
+ class A:
+ @cached(max_entries=10)
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ a = A()
+
+ for k in range(0, 12):
+ yield a.func(k)
+
+ self.assertEquals(callcount[0], 12)
+
+ # There must have been at least 2 evictions, meaning if we calculate
+ # all 12 values again, we must get called at least 2 more times
+ for k in range(0, 12):
+ yield a.func(k)
+
+ self.assertTrue(
+ callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
+ )
+
+ def test_prefill(self):
+ callcount = [0]
+
+ d = defer.succeed(123)
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return d
+
+ a = A()
+
+ a.func.prefill(("foo",), 456)
+
+ self.assertEquals(a.func("foo").result, 456)
+ self.assertEquals(callcount[0], 0)
+
+ @defer.inlineCallbacks
+ def test_invalidate_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func.invalidate(("foo",))
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 1)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ @defer.inlineCallbacks
+ def test_eviction_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A:
+ @cached(max_entries=2)
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+ yield a.func2("foo2")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func("foo3")
+
+ self.assertEquals(callcount[0], 3)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 4)
+ self.assertEquals(callcount2[0], 3)
+
+ @defer.inlineCallbacks
+ def test_double_get(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+ yield a.func2("foo")
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 2)
+
+ a.func.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 3)
+
+
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
|