summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/_base.py30
1 files changed, 27 insertions, 3 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 842c4e2982..d4751769e4 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -149,6 +149,9 @@ class Cache(object):
 class CacheDescriptor(object):
     """ A method decorator that applies a memoizing cache around the function.
 
+    This caches deferreds, rather than the results themselves. Deferreds that
+    fail are removed from the cache.
+
     The function is presumed to take zero or more arguments, which are used in
     a tuple as the key for the cache. Hits are served directly from the cache;
     misses use the function body to generate the value.
@@ -195,8 +198,23 @@ class CacheDescriptor(object):
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             try:
-                cached_result = cache.get(*keyargs)
-                return cached_result.observe()
+                cached_result_d = cache.get(*keyargs)
+
+                if DEBUG_CACHES:
+
+                    @defer.inlineCallbacks
+                    def check_result(cached_result):
+                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
+                        if actual_result != cached_result:
+                            logger.error(
+                                "Stale cache entry %s%r: cached: %r, actual %r",
+                                self.orig.__name__, keyargs,
+                                cached_result, actual_result,
+                            )
+                            raise ValueError("Stale cache entry")
+                    cached_result_d.observe().addCallback(check_result)
+
+                return cached_result_d.observe()
             except KeyError:
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
@@ -207,8 +225,14 @@ class CacheDescriptor(object):
                     self.function_to_call,
                     obj, *args, **kwargs
                 )
-                ret = ObservableDeferred(ret, consumeErrors=False)
 
+                def onErr(f):
+                    cache.invalidate(*keyargs)
+                    return f
+
+                ret.addErrback(onErr)
+
+                ret = ObservableDeferred(ret, consumeErrors=False)
                 cache.update(sequence, *(keyargs + [ret]))
 
                 return ret.observe()