diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index e27917c63a..277854ccbc 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
+from synapse.util.logcontext import (
+ PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
+)
from . import caches_by_name, DEBUG_CACHES, cache_counter
@@ -190,7 +193,7 @@ class CacheDescriptor(object):
defer.returnValue(cached_result)
observer.addCallback(check_result)
- return observer
+ return preserve_context_over_deferred(observer)
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
@@ -198,6 +201,7 @@ class CacheDescriptor(object):
sequence = self.cache.sequence
ret = defer.maybeDeferred(
+ preserve_context_over_fn,
self.function_to_call,
obj, *args, **kwargs
)
@@ -211,7 +215,7 @@ class CacheDescriptor(object):
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
- return ret.observe()
+ return preserve_context_over_deferred(ret.observe())
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
@@ -299,6 +303,7 @@ class CacheListDescriptor(object):
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
+ preserve_context_over_fn,
self.function_to_call,
**args_to_call
)
@@ -308,7 +313,8 @@ class CacheListDescriptor(object):
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
- observer = ret_d.observe()
+ with PreserveLoggingContext():
+ observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer)
@@ -327,10 +333,10 @@ class CacheListDescriptor(object):
cached[arg] = res
- return defer.gatherResults(
+ return preserve_context_over_deferred(defer.gatherResults(
cached.values(),
consumeErrors=True,
- ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
+ ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
obj.__dict__[self.orig.__name__] = wrapped
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
index b1e40417fd..d03678b8c8 100644
--- a/synapse/util/caches/snapshot_cache.py
+++ b/synapse/util/caches/snapshot_cache.py
@@ -87,7 +87,8 @@ class SnapshotCache(object):
# expire from the rotation of that cache.
self.next_result_cache[key] = result
self.pending_result_cache.pop(key, None)
+ return r
- result.observe().addBoth(shuffle_along)
+ result.addBoth(shuffle_along)
return result.observe()
|