diff options
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r-- | synapse/util/caches/descriptors.py | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 88e56e3302..277854ccbc 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred from synapse.util import unwrapFirstError from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn +) from . import caches_by_name, DEBUG_CACHES, cache_counter @@ -149,7 +152,7 @@ class CacheDescriptor(object): self.lru = lru self.tree = tree - self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] if len(self.arg_names) < self.num_args: raise Exception( @@ -190,7 +193,7 @@ class CacheDescriptor(object): defer.returnValue(cached_result) observer.addCallback(check_result) - return observer + return preserve_context_over_deferred(observer) except KeyError: # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated @@ -198,6 +201,7 @@ class CacheDescriptor(object): sequence = self.cache.sequence ret = defer.maybeDeferred( + preserve_context_over_fn, self.function_to_call, obj, *args, **kwargs ) @@ -211,7 +215,7 @@ class CacheDescriptor(object): ret = ObservableDeferred(ret, consumeErrors=True) self.cache.update(sequence, cache_key, ret) - return ret.observe() + return preserve_context_over_deferred(ret.observe()) wrapped.invalidate = self.cache.invalidate wrapped.invalidate_all = self.cache.invalidate_all @@ -250,7 +254,7 @@ class CacheListDescriptor(object): self.num_args = num_args self.list_name = list_name - self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.list_pos = self.arg_names.index(self.list_name) self.cache = cache @@ -299,6 +303,7 @@ class CacheListDescriptor(object): args_to_call[self.list_name] = missing ret_d = defer.maybeDeferred( + preserve_context_over_fn, self.function_to_call, **args_to_call ) @@ -308,7 +313,8 @@ class CacheListDescriptor(object): # We need to create deferreds for each arg in the list so that # we can insert the new deferred into the cache. for arg in missing: - observer = ret_d.observe() + with PreserveLoggingContext(): + observer = ret_d.observe() observer.addCallback(lambda r, arg: r.get(arg, None), arg) observer = ObservableDeferred(observer) @@ -327,10 +333,10 @@ class CacheListDescriptor(object): cached[arg] = res - return defer.gatherResults( + return preserve_context_over_deferred(defer.gatherResults( cached.values(), consumeErrors=True, - ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)) + ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))) obj.__dict__[self.orig.__name__] = wrapped |