diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 19595df422..5c30ed235d 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -15,12 +15,9 @@
import logging
from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.logcontext import (
- PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
-)
from . import DEBUG_CACHES, register_cache
@@ -328,11 +325,9 @@ class CacheDescriptor(_CacheDescriptorBase):
defer.returnValue(cached_result)
observer.addCallback(check_result)
- return preserve_context_over_deferred(observer)
except KeyError:
ret = defer.maybeDeferred(
- preserve_context_over_fn,
- self.function_to_call,
+ logcontext.preserve_fn(self.function_to_call),
obj, *args, **kwargs
)
@@ -342,10 +337,11 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
- ret = ObservableDeferred(ret, consumeErrors=True)
- cache.set(cache_key, ret, callback=invalidate_callback)
+ result_d = ObservableDeferred(ret, consumeErrors=True)
+ cache.set(cache_key, result_d, callback=invalidate_callback)
+ observer = result_d.observe()
- return preserve_context_over_deferred(ret.observe())
+ return logcontext.make_deferred_yieldable(observer)
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
@@ -362,7 +358,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes
- the list of missing keys to the wrapped fucntion.
+ the list of missing keys to the wrapped function.
+
+ Once wrapped, the function returns either a Deferred which resolves to
+ the list of results, or (if all results were cached), just the list of
+ results.
"""
def __init__(self, orig, cached_method_name, list_name, num_args=None,
@@ -433,8 +433,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
- preserve_context_over_fn,
- self.function_to_call,
+ logcontext.preserve_fn(self.function_to_call),
**args_to_call
)
@@ -443,8 +442,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
- with PreserveLoggingContext():
- observer = ret_d.observe()
+ observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer)
@@ -471,7 +469,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
results.update(res)
return results
- return preserve_context_over_deferred(defer.gatherResults(
+ return logcontext.make_deferred_yieldable(defer.gatherResults(
cached_defers.values(),
consumeErrors=True,
).addCallback(update_results_dict).addErrback(
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index ff67b1d794..857afee7cb 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -310,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs):
def preserve_context_over_deferred(deferred, context=None):
"""Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context.
+
+ Deprecated: this almost certainly doesn't do want you want, ie make
+ the deferred follow the synapse logcontext rules: try
+ ``make_deferred_yieldable`` instead.
"""
if context is None:
context = LoggingContext.current_context()
@@ -359,6 +363,25 @@ def preserve_fn(f):
return g
+@defer.inlineCallbacks
+def make_deferred_yieldable(deferred):
+ """Given a deferred, make it follow the Synapse logcontext rules:
+
+ If the deferred has completed (or is not actually a Deferred), essentially
+ does nothing (just returns another completed deferred with the
+ result/failure).
+
+ If the deferred has not yet completed, resets the logcontext before
+ returning a deferred. Then, when the deferred completes, restores the
+ current logcontext before running callbacks/errbacks.
+
+ (This is more-or-less the opposite operation to preserve_fn.)
+ """
+ with PreserveLoggingContext():
+ r = yield deferred
+ defer.returnValue(r)
+
+
# modules to ignore in `logcontext_tracer`
_to_ignore = [
"synapse.util.logcontext",
|