summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8572.misc1
-rw-r--r--synapse/util/caches/deferred_cache.py57
-rw-r--r--synapse/util/caches/descriptors.py32
-rw-r--r--tests/storage/test__base.py228
-rw-r--r--tests/util/caches/test_deferred_cache.py113
-rw-r--r--tests/util/caches/test_descriptors.py281
6 files changed, 441 insertions, 271 deletions
diff --git a/changelog.d/8572.misc b/changelog.d/8572.misc
new file mode 100644
index 0000000000..ea2a6d340d
--- /dev/null
+++ b/changelog.d/8572.misc
@@ -0,0 +1 @@
+Modify `DeferredCache.get()` to return `Deferred`s instead of `ObservableDeferred`s.
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index faeef75506..6c162e9f34 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -57,7 +57,7 @@ class DeferredCache(Generic[KT, VT]):
     """Wraps an LruCache, adding support for Deferred results.
 
     It expects that each entry added with set() will be a Deferred; likewise get()
-    may return an ObservableDeferred.
+    will return a Deferred.
     """
 
     __slots__ = (
@@ -130,16 +130,22 @@ class DeferredCache(Generic[KT, VT]):
         key: KT,
         callback: Optional[Callable[[], None]] = None,
         update_metrics: bool = True,
-    ) -> Union[ObservableDeferred, VT]:
+    ) -> defer.Deferred:
         """Looks the key up in the caches.
 
+        For symmetry with set(), this method does *not* follow the synapse logcontext
+        rules: the logcontext will not be cleared on return, and the Deferred will run
+        its callbacks in the sentinel context. In other words: wrap the result with
+        make_deferred_yieldable() before `await`ing it.
+
         Args:
-            key(tuple)
-            callback(fn): Gets called when the entry in the cache is invalidated
+            key:
+            callback: Gets called when the entry in the cache is invalidated
             update_metrics (bool): whether to update the cache hit rate metrics
 
         Returns:
-            Either an ObservableDeferred or the result itself
+            A Deferred which completes with the result. Note that this may later fail
+            if there is an ongoing set() operation which later completes with a failure.
 
         Raises:
             KeyError if the key is not found in the cache
@@ -152,7 +158,7 @@ class DeferredCache(Generic[KT, VT]):
                 m = self.cache.metrics
                 assert m  # we always have a name, so should always have metrics
                 m.inc_hits()
-            return val.deferred
+            return val.deferred.observe()
 
         val2 = self.cache.get(
             key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@@ -160,7 +166,7 @@ class DeferredCache(Generic[KT, VT]):
         if val2 is _Sentinel.sentinel:
             raise KeyError()
         else:
-            return val2
+            return defer.succeed(val2)
 
     def get_immediate(
         self, key: KT, default: T, update_metrics: bool = True
@@ -173,7 +179,36 @@ class DeferredCache(Generic[KT, VT]):
         key: KT,
         value: defer.Deferred,
         callback: Optional[Callable[[], None]] = None,
-    ) -> ObservableDeferred:
+    ) -> defer.Deferred:
+        """Adds a new entry to the cache (or updates an existing one).
+
+        The given `value` *must* be a Deferred.
+
+        First any existing entry for the same key is invalidated. Then a new entry
+        is added to the cache for the given key.
+
+        Until the `value` completes, calls to `get()` for the key will also result in an
+        incomplete Deferred, which will ultimately complete with the same result as
+        `value`.
+
+        If `value` completes successfully, subsequent calls to `get()` will then return
+        a completed deferred with the same result. If it *fails*, the cache is
+        invalidated and subequent calls to `get()` will raise a KeyError.
+
+        If another call to `set()` happens before `value` completes, then (a) any
+        invalidation callbacks registered in the interim will be called, (b) any
+        `get()`s in the interim will continue to complete with the result from the
+        *original* `value`, (c) any future calls to `get()` will complete with the
+        result from the *new* `value`.
+
+        It is expected that `value` does *not* follow the synapse logcontext rules - ie,
+        if it is incomplete, it runs its callbacks in the sentinel context.
+
+        Args:
+            key: Key to be set
+            value: a deferred which will complete with a result to add to the cache
+            callback: An optional callback to be called when the entry is invalidated
+        """
         if not isinstance(value, defer.Deferred):
             raise TypeError("not a Deferred")
 
@@ -187,6 +222,8 @@ class DeferredCache(Generic[KT, VT]):
         if existing_entry:
             existing_entry.invalidate()
 
+        # XXX: why don't we invalidate the entry in `self.cache` yet?
+
         self._pending_deferred_cache[key] = entry
 
         def compare_and_pop():
@@ -230,7 +267,9 @@ class DeferredCache(Generic[KT, VT]):
         # _pending_deferred_cache to the real cache.
         #
         observer.addCallbacks(cb, eb)
-        return observable
+
+        # we return a new Deferred which will be called before any subsequent observers.
+        return observable.observe()
 
     def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
         callbacks = [callback] if callback else []
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 1f43886804..a4172345ef 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -23,7 +23,6 @@ from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.deferred_cache import DeferredCache
 
 logger = logging.getLogger(__name__)
@@ -156,7 +155,7 @@ class CacheDescriptor(_CacheDescriptorBase):
             keylen=self.num_args,
             tree=self.tree,
             iterable=self.iterable,
-        )  # type: DeferredCache[Tuple, Any]
+        )  # type: DeferredCache[CacheKey, Any]
 
         def get_cache_key_gen(args, kwargs):
             """Given some args/kwargs return a generator that resolves into
@@ -208,26 +207,12 @@ class CacheDescriptor(_CacheDescriptorBase):
                 kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
 
             try:
-                cached_result_d = cache.get(cache_key, callback=invalidate_callback)
-
-                if isinstance(cached_result_d, ObservableDeferred):
-                    observer = cached_result_d.observe()
-                else:
-                    observer = defer.succeed(cached_result_d)
-
+                ret = cache.get(cache_key, callback=invalidate_callback)
             except KeyError:
                 ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
+                ret = cache.set(cache_key, ret, callback=invalidate_callback)
 
-                def onErr(f):
-                    cache.invalidate(cache_key)
-                    return f
-
-                ret.addErrback(onErr)
-
-                result_d = cache.set(cache_key, ret, callback=invalidate_callback)
-                observer = result_d.observe()
-
-            return make_deferred_yieldable(observer)
+            return make_deferred_yieldable(ret)
 
         wrapped = cast(_CachedFunction, _wrapped)
 
@@ -286,7 +271,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj, objtype=None):
         cached_method = getattr(obj, self.cached_method_name)
-        cache = cached_method.cache
+        cache = cached_method.cache  # type: DeferredCache[CacheKey, Any]
         num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
@@ -326,14 +311,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
             for arg in list_args:
                 try:
                     res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
-                    if not isinstance(res, ObservableDeferred):
-                        results[arg] = res
-                    elif not res.has_succeeded():
-                        res = res.observe()
+                    if not res.called:
                         res.addCallback(update_results_dict, arg)
                         cached_defers.append(res)
                     else:
-                        results[arg] = res.get_result()
+                        results[arg] = res.result
                 except KeyError:
                     missing.add(arg)
 
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 8e69b1e9cc..1ac4ebc61d 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -15,237 +15,9 @@
 # limitations under the License.
 
 
-from mock import Mock
-
-from twisted.internet import defer
-
-from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import cached
-
 from tests import unittest
 
 
-class CacheDecoratorTestCase(unittest.HomeserverTestCase):
-    @defer.inlineCallbacks
-    def test_passthrough(self):
-        class A:
-            @cached()
-            def func(self, key):
-                return key
-
-        a = A()
-
-        self.assertEquals((yield a.func("foo")), "foo")
-        self.assertEquals((yield a.func("bar")), "bar")
-
-    @defer.inlineCallbacks
-    def test_hit(self):
-        callcount = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-        a = A()
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 1)
-
-        self.assertEquals((yield a.func("foo")), "foo")
-        self.assertEquals(callcount[0], 1)
-
-    @defer.inlineCallbacks
-    def test_invalidate(self):
-        callcount = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-        a = A()
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 1)
-
-        a.func.invalidate(("foo",))
-
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 2)
-
-    def test_invalidate_missing(self):
-        class A:
-            @cached()
-            def func(self, key):
-                return key
-
-        A().func.invalidate(("what",))
-
-    @defer.inlineCallbacks
-    def test_max_entries(self):
-        callcount = [0]
-
-        class A:
-            @cached(max_entries=10)
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-        a = A()
-
-        for k in range(0, 12):
-            yield a.func(k)
-
-        self.assertEquals(callcount[0], 12)
-
-        # There must have been at least 2 evictions, meaning if we calculate
-        # all 12 values again, we must get called at least 2 more times
-        for k in range(0, 12):
-            yield a.func(k)
-
-        self.assertTrue(
-            callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
-        )
-
-    def test_prefill(self):
-        callcount = [0]
-
-        d = defer.succeed(123)
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return d
-
-        a = A()
-
-        a.func.prefill(("foo",), ObservableDeferred(d))
-
-        self.assertEquals(a.func("foo").result, d.result)
-        self.assertEquals(callcount[0], 0)
-
-    @defer.inlineCallbacks
-    def test_invalidate_context(self):
-        callcount = [0]
-        callcount2 = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-            @cached(cache_context=True)
-            def func2(self, key, cache_context):
-                callcount2[0] += 1
-                return self.func(key, on_invalidate=cache_context.invalidate)
-
-        a = A()
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 1)
-
-        a.func.invalidate(("foo",))
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 1)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-    @defer.inlineCallbacks
-    def test_eviction_context(self):
-        callcount = [0]
-        callcount2 = [0]
-
-        class A:
-            @cached(max_entries=2)
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-            @cached(cache_context=True)
-            def func2(self, key, cache_context):
-                callcount2[0] += 1
-                return self.func(key, on_invalidate=cache_context.invalidate)
-
-        a = A()
-        yield a.func2("foo")
-        yield a.func2("foo2")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func2("foo")
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func("foo3")
-
-        self.assertEquals(callcount[0], 3)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 4)
-        self.assertEquals(callcount2[0], 3)
-
-    @defer.inlineCallbacks
-    def test_double_get(self):
-        callcount = [0]
-        callcount2 = [0]
-
-        class A:
-            @cached()
-            def func(self, key):
-                callcount[0] += 1
-                return key
-
-            @cached(cache_context=True)
-            def func2(self, key, cache_context):
-                callcount2[0] += 1
-                return self.func(key, on_invalidate=cache_context.invalidate)
-
-        a = A()
-        a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 1)
-
-        a.func2.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
-
-        yield a.func2("foo")
-        a.func2.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
-
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 2)
-
-        a.func.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
-        yield a.func("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
-
-        yield a.func2("foo")
-
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 3)
-
-
 class UpsertManyTests(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         self.storage = hs.get_datastore()
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index 8a08ab6661..dadfabd46d 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -13,15 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import unittest
 from functools import partial
 
 from twisted.internet import defer
 
 from synapse.util.caches.deferred_cache import DeferredCache
 
+from tests.unittest import TestCase
 
-class DeferredCacheTestCase(unittest.TestCase):
+
+class DeferredCacheTestCase(TestCase):
     def test_empty(self):
         cache = DeferredCache("test")
         failed = False
@@ -36,7 +37,102 @@ class DeferredCacheTestCase(unittest.TestCase):
         cache = DeferredCache("test")
         cache.prefill("foo", 123)
 
-        self.assertEquals(cache.get("foo"), 123)
+        self.assertEquals(self.successResultOf(cache.get("foo")), 123)
+
+    def test_hit_deferred(self):
+        cache = DeferredCache("test")
+        origin_d = defer.Deferred()
+        set_d = cache.set("k1", origin_d)
+
+        # get should return an incomplete deferred
+        get_d = cache.get("k1")
+        self.assertFalse(get_d.called)
+
+        # add a callback that will make sure that the set_d gets called before the get_d
+        def check1(r):
+            self.assertTrue(set_d.called)
+            return r
+
+        # TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
+        #   maybe we should fix that?
+        # get_d.addCallback(check1)
+
+        # now fire off all the deferreds
+        origin_d.callback(99)
+        self.assertEqual(self.successResultOf(origin_d), 99)
+        self.assertEqual(self.successResultOf(set_d), 99)
+        self.assertEqual(self.successResultOf(get_d), 99)
+
+    def test_callbacks(self):
+        """Invalidation callbacks are called at the right time"""
+        cache = DeferredCache("test")
+        callbacks = set()
+
+        # start with an entry, with a callback
+        cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
+
+        # now replace that entry with a pending result
+        origin_d = defer.Deferred()
+        set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
+
+        # ... and also make a get request
+        get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
+
+        # we don't expect the invalidation callback for the original value to have
+        # been called yet, even though get() will now return a different result.
+        # I'm not sure if that is by design or not.
+        self.assertEqual(callbacks, set())
+
+        # now fire off all the deferreds
+        origin_d.callback(20)
+        self.assertEqual(self.successResultOf(set_d), 20)
+        self.assertEqual(self.successResultOf(get_d), 20)
+
+        # now the original invalidation callback should have been called, but none of
+        # the others
+        self.assertEqual(callbacks, {"prefill"})
+        callbacks.clear()
+
+        # another update should invalidate both the previous results
+        cache.prefill("k1", 30)
+        self.assertEqual(callbacks, {"set", "get"})
+
+    def test_set_fail(self):
+        cache = DeferredCache("test")
+        callbacks = set()
+
+        # start with an entry, with a callback
+        cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
+
+        # now replace that entry with a pending result
+        origin_d = defer.Deferred()
+        set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
+
+        # ... and also make a get request
+        get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
+
+        # none of the callbacks should have been called yet
+        self.assertEqual(callbacks, set())
+
+        # oh noes! fails!
+        e = Exception("oops")
+        origin_d.errback(e)
+        self.assertIs(self.failureResultOf(set_d, Exception).value, e)
+        self.assertIs(self.failureResultOf(get_d, Exception).value, e)
+
+        # the callbacks for the failed requests should have been called.
+        # I'm not sure if this is deliberate or not.
+        self.assertEqual(callbacks, {"get", "set"})
+        callbacks.clear()
+
+        # the old value should still be returned now?
+        get_d2 = cache.get("k1", callback=lambda: callbacks.add("get2"))
+        self.assertEqual(self.successResultOf(get_d2), 10)
+
+        # replacing the value now should run the callbacks for those requests
+        # which got the original result
+        cache.prefill("k1", 30)
+        self.assertEqual(callbacks, {"prefill", "get2"})
 
     def test_get_immediate(self):
         cache = DeferredCache("test")
@@ -82,16 +178,15 @@ class DeferredCacheTestCase(unittest.TestCase):
         d2 = defer.Deferred()
         cache.set("key2", d2, partial(record_callback, 1))
 
-        # lookup should return observable deferreds
-        self.assertFalse(cache.get("key1").has_called())
-        self.assertFalse(cache.get("key2").has_called())
+        # lookup should return pending deferreds
+        self.assertFalse(cache.get("key1").called)
+        self.assertFalse(cache.get("key2").called)
 
         # let one of the lookups complete
         d2.callback("result2")
 
-        # for now at least, the cache will return real results rather than an
-        # observabledeferred
-        self.assertEqual(cache.get("key2"), "result2")
+        # now the cache will return a completed deferred
+        self.assertEqual(self.successResultOf(cache.get("key2")), "result2")
 
         # now do the invalidation
         cache.invalidate_all()
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 3d1f960869..2ad08f541b 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import Set
 
 import mock
 
@@ -130,6 +131,57 @@ class DescriptorTestCase(unittest.TestCase):
         d = obj.fn(1)
         self.failureResultOf(d, SynapseError)
 
+    def test_cache_with_async_exception(self):
+        """The wrapped function returns a failure
+        """
+
+        class Cls:
+            result = None
+            call_count = 0
+
+            @cached()
+            def fn(self, arg1):
+                self.call_count += 1
+                return self.result
+
+        obj = Cls()
+        callbacks = set()  # type: Set[str]
+
+        # set off an asynchronous request
+        obj.result = origin_d = defer.Deferred()
+
+        d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
+        self.assertFalse(d1.called)
+
+        # a second request should also return a deferred, but should not call the
+        # function itself.
+        d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2"))
+        self.assertFalse(d2.called)
+        self.assertEqual(obj.call_count, 1)
+
+        # no callbacks yet
+        self.assertEqual(callbacks, set())
+
+        # the original request fails
+        e = Exception("bzz")
+        origin_d.errback(e)
+
+        # ... which should cause the lookups to fail similarly
+        self.assertIs(self.failureResultOf(d1, Exception).value, e)
+        self.assertIs(self.failureResultOf(d2, Exception).value, e)
+
+        # ... and the callbacks to have been, uh, called.
+        self.assertEqual(callbacks, {"d1", "d2"})
+
+        # ... leaving the cache empty
+        self.assertEqual(len(obj.fn.cache.cache), 0)
+
+        # and a second call should work as normal
+        obj.result = defer.succeed(100)
+        d3 = obj.fn(1)
+        self.assertEqual(self.successResultOf(d3), 100)
+        self.assertEqual(obj.call_count, 2)
+
     def test_cache_logcontexts(self):
         """Check that logcontexts are set and restored correctly when
         using the cache."""
@@ -311,6 +363,235 @@ class DescriptorTestCase(unittest.TestCase):
         self.failureResultOf(d, SynapseError)
 
 
+class CacheDecoratorTestCase(unittest.HomeserverTestCase):
+    """More tests for @cached
+
+    The following is a set of tests that got lost in a different file for a while.
+
+    There are probably duplicates of the tests in DescriptorTestCase. Ideally the
+    duplicates would be removed and the two sets of classes combined.
+    """
+
+    @defer.inlineCallbacks
+    def test_passthrough(self):
+        class A:
+            @cached()
+            def func(self, key):
+                return key
+
+        a = A()
+
+        self.assertEquals((yield a.func("foo")), "foo")
+        self.assertEquals((yield a.func("bar")), "bar")
+
+    @defer.inlineCallbacks
+    def test_hit(self):
+        callcount = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 1)
+
+        self.assertEquals((yield a.func("foo")), "foo")
+        self.assertEquals(callcount[0], 1)
+
+    @defer.inlineCallbacks
+    def test_invalidate(self):
+        callcount = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 1)
+
+        a.func.invalidate(("foo",))
+
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+
+    def test_invalidate_missing(self):
+        class A:
+            @cached()
+            def func(self, key):
+                return key
+
+        A().func.invalidate(("what",))
+
+    @defer.inlineCallbacks
+    def test_max_entries(self):
+        callcount = [0]
+
+        class A:
+            @cached(max_entries=10)
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
+
+        for k in range(0, 12):
+            yield a.func(k)
+
+        self.assertEquals(callcount[0], 12)
+
+        # There must have been at least 2 evictions, meaning if we calculate
+        # all 12 values again, we must get called at least 2 more times
+        for k in range(0, 12):
+            yield a.func(k)
+
+        self.assertTrue(
+            callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
+        )
+
+    def test_prefill(self):
+        callcount = [0]
+
+        d = defer.succeed(123)
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return d
+
+        a = A()
+
+        a.func.prefill(("foo",), 456)
+
+        self.assertEquals(a.func("foo").result, 456)
+        self.assertEquals(callcount[0], 0)
+
+    @defer.inlineCallbacks
+    def test_invalidate_context(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 1)
+
+        a.func.invalidate(("foo",))
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 1)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+    @defer.inlineCallbacks
+    def test_eviction_context(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A:
+            @cached(max_entries=2)
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        yield a.func2("foo")
+        yield a.func2("foo2")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func("foo3")
+
+        self.assertEquals(callcount[0], 3)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 4)
+        self.assertEquals(callcount2[0], 3)
+
+    @defer.inlineCallbacks
+    def test_double_get(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A:
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 1)
+
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+        yield a.func2("foo")
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 2)
+
+        a.func.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 3)
+
+
 class CachedListDescriptorTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_cache(self):