summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/_base.py46
-rw-r--r--synapse/util/async.py9
-rw-r--r--tests/storage/test__base.py11
3 files changed, 47 insertions, 19 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index f74b81b7a2..5997603b3c 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,6 +15,7 @@
 import logging
 
 from synapse.api.errors import StoreError
+from synapse.util.async import ObservableDeferred
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
 from synapse.util.lrucache import LruCache
@@ -131,6 +132,9 @@ class Cache(object):
 class CacheDescriptor(object):
     """ A method decorator that applies a memoizing cache around the function.
 
+    This caches deferreds, rather than the results themselves. Deferreds that
+    fail are removed from the cache.
+
     The function is presumed to take zero or more arguments, which are used in
     a tuple as the key for the cache. Hits are served directly from the cache;
     misses use the function body to generate the value.
@@ -173,33 +177,49 @@ class CacheDescriptor(object):
         )
 
         @functools.wraps(self.orig)
-        @defer.inlineCallbacks
         def wrapped(*args, **kwargs):
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             try:
-                cached_result = cache.get(*keyargs)
+                cached_result_d = cache.get(*keyargs)
+
+                observer = cached_result_d.observe()
                 if DEBUG_CACHES:
-                    actual_result = yield self.function_to_call(obj, *args, **kwargs)
-                    if actual_result != cached_result:
-                        logger.error(
-                            "Stale cache entry %s%r: cached: %r, actual %r",
-                            self.orig.__name__, keyargs,
-                            cached_result, actual_result,
-                        )
-                        raise ValueError("Stale cache entry")
-                defer.returnValue(cached_result)
+                    @defer.inlineCallbacks
+                    def check_result(cached_result):
+                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
+                        if actual_result != cached_result:
+                            logger.error(
+                                "Stale cache entry %s%r: cached: %r, actual %r",
+                                self.orig.__name__, keyargs,
+                                cached_result, actual_result,
+                            )
+                            raise ValueError("Stale cache entry")
+                        defer.returnValue(cached_result)
+                    observer.addCallback(check_result)
+
+                return observer
             except KeyError:
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
                 # while the SELECT is executing (SYN-369)
                 sequence = cache.sequence
 
-                ret = yield self.function_to_call(obj, *args, **kwargs)
+                ret = defer.maybeDeferred(
+                    self.function_to_call,
+                    obj, *args, **kwargs
+                )
+
+                def onErr(f):
+                    cache.invalidate(*keyargs)
+                    return f
+
+                ret.addErrback(onErr)
 
+                ret = ObservableDeferred(ret, consumeErrors=False)
                 cache.update(sequence, *(keyargs + [ret]))
 
-                defer.returnValue(ret)
+                return ret.observe()
 
         wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 5a1d545c96..7bf2d38bb8 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -51,7 +51,7 @@ class ObservableDeferred(object):
         object.__setattr__(self, "_observers", set())
 
         def callback(r):
-            self._result = (True, r)
+            object.__setattr__(self, "_result", (True, r))
             while self._observers:
                 try:
                     self._observers.pop().callback(r)
@@ -60,7 +60,7 @@ class ObservableDeferred(object):
             return r
 
         def errback(f):
-            self._result = (False, f)
+            object.__setattr__(self, "_result", (False, f))
             while self._observers:
                 try:
                     self._observers.pop().errback(f)
@@ -97,3 +97,8 @@ class ObservableDeferred(object):
 
     def __setattr__(self, name, value):
         setattr(self._deferred, name, value)
+
+    def __repr__(self):
+        return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
+            id(self), self._result, self._deferred,
+        )
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 8c3d2952bd..8fa305d18a 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -17,6 +17,8 @@
 from tests import unittest
 from twisted.internet import defer
 
+from synapse.util.async import ObservableDeferred
+
 from synapse.storage._base import Cache, cached
 
 
@@ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase):
         self.assertTrue(callcount[0] >= 14,
             msg="Expected callcount >= 14, got %d" % (callcount[0]))
 
-    @defer.inlineCallbacks
     def test_prefill(self):
         callcount = [0]
 
+        d = defer.succeed(123)
+
         class A(object):
             @cached()
             def func(self, key):
                 callcount[0] += 1
-                return key
+                return d
 
         a = A()
 
-        a.func.prefill("foo", 123)
+        a.func.prefill("foo", ObservableDeferred(d))
 
-        self.assertEquals((yield a.func("foo")), 123)
+        self.assertEquals(a.func("foo").result, d.result)
         self.assertEquals(callcount[0], 0)