diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index f74b81b7a2..5997603b3c 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,6 +15,7 @@
import logging
from synapse.api.errors import StoreError
+from synapse.util.async import ObservableDeferred
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache
@@ -131,6 +132,9 @@ class Cache(object):
class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
+ This caches deferreds, rather than the results themselves. Deferreds that
+ fail are removed from the cache.
+
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
@@ -173,33 +177,49 @@ class CacheDescriptor(object):
)
@functools.wraps(self.orig)
- @defer.inlineCallbacks
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
try:
- cached_result = cache.get(*keyargs)
+ cached_result_d = cache.get(*keyargs)
+
+ observer = cached_result_d.observe()
if DEBUG_CACHES:
- actual_result = yield self.function_to_call(obj, *args, **kwargs)
- if actual_result != cached_result:
- logger.error(
- "Stale cache entry %s%r: cached: %r, actual %r",
- self.orig.__name__, keyargs,
- cached_result, actual_result,
- )
- raise ValueError("Stale cache entry")
- defer.returnValue(cached_result)
+ @defer.inlineCallbacks
+ def check_result(cached_result):
+ actual_result = yield self.function_to_call(obj, *args, **kwargs)
+ if actual_result != cached_result:
+ logger.error(
+ "Stale cache entry %s%r: cached: %r, actual %r",
+ self.orig.__name__, keyargs,
+ cached_result, actual_result,
+ )
+ raise ValueError("Stale cache entry")
+ defer.returnValue(cached_result)
+ observer.addCallback(check_result)
+
+ return observer
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
- ret = yield self.function_to_call(obj, *args, **kwargs)
+ ret = defer.maybeDeferred(
+ self.function_to_call,
+ obj, *args, **kwargs
+ )
+
+ def onErr(f):
+ cache.invalidate(*keyargs)
+ return f
+
+ ret.addErrback(onErr)
+ ret = ObservableDeferred(ret, consumeErrors=False)
cache.update(sequence, *(keyargs + [ret]))
- defer.returnValue(ret)
+ return ret.observe()
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 5a1d545c96..7bf2d38bb8 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -51,7 +51,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_observers", set())
def callback(r):
- self._result = (True, r)
+ object.__setattr__(self, "_result", (True, r))
while self._observers:
try:
self._observers.pop().callback(r)
@@ -60,7 +60,7 @@ class ObservableDeferred(object):
return r
def errback(f):
- self._result = (False, f)
+ object.__setattr__(self, "_result", (False, f))
while self._observers:
try:
self._observers.pop().errback(f)
@@ -97,3 +97,8 @@ class ObservableDeferred(object):
def __setattr__(self, name, value):
setattr(self._deferred, name, value)
+
+ def __repr__(self):
+ return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
+ id(self), self._result, self._deferred,
+ )
|