diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/_base.py | 46 |
1 files changed, 33 insertions, 13 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f74b81b7a2..5997603b3c 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,6 +15,7 @@ import logging from synapse.api.errors import StoreError +from synapse.util.async import ObservableDeferred from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.lrucache import LruCache @@ -131,6 +132,9 @@ class Cache(object): class CacheDescriptor(object): """ A method decorator that applies a memoizing cache around the function. + This caches deferreds, rather than the results themselves. Deferreds that + fail are removed from the cache. + The function is presumed to take zero or more arguments, which are used in a tuple as the key for the cache. Hits are served directly from the cache; misses use the function body to generate the value. @@ -173,33 +177,49 @@ class CacheDescriptor(object): ) @functools.wraps(self.orig) - @defer.inlineCallbacks def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] try: - cached_result = cache.get(*keyargs) + cached_result_d = cache.get(*keyargs) + + observer = cached_result_d.observe() if DEBUG_CACHES: - actual_result = yield self.function_to_call(obj, *args, **kwargs) - if actual_result != cached_result: - logger.error( - "Stale cache entry %s%r: cached: %r, actual %r", - self.orig.__name__, keyargs, - cached_result, actual_result, - ) - raise ValueError("Stale cache entry") - defer.returnValue(cached_result) + @defer.inlineCallbacks + def check_result(cached_result): + actual_result = yield self.function_to_call(obj, *args, **kwargs) + if actual_result != cached_result: + logger.error( + "Stale cache entry %s%r: cached: %r, actual %r", + self.orig.__name__, keyargs, + cached_result, actual_result, + ) + raise ValueError("Stale cache entry") + defer.returnValue(cached_result) + observer.addCallback(check_result) + + return observer except KeyError: # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) sequence = cache.sequence - ret = yield self.function_to_call(obj, *args, **kwargs) + ret = defer.maybeDeferred( + self.function_to_call, + obj, *args, **kwargs + ) + + def onErr(f): + cache.invalidate(*keyargs) + return f + + ret.addErrback(onErr) + ret = ObservableDeferred(ret, consumeErrors=False) cache.update(sequence, *(keyargs + [ret])) - defer.returnValue(ret) + return ret.observe() wrapped.invalidate = cache.invalidate wrapped.invalidate_all = cache.invalidate_all |