diff options
Diffstat (limited to 'synapse/util/logcontext.py')
-rw-r--r-- | synapse/util/logcontext.py | 52 |
1 files changed, 33 insertions, 19 deletions
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index a92d518b43..7e6062c1b8 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -140,6 +140,37 @@ class PreserveLoggingContext(object): ) +class _PreservingContextDeferred(defer.Deferred): + """A deferred that ensures that all callbacks and errbacks are called with + the given logging context. + """ + def __init__(self, context): + self._log_context = context + defer.Deferred.__init__(self) + + def addCallbacks(self, callback, errback=None, + callbackArgs=None, callbackKeywords=None, + errbackArgs=None, errbackKeywords=None): + callback = self._wrap_callback(callback) + errback = self._wrap_callback(errback) + return defer.Deferred.addCallbacks( + self, callback, + errback=errback, + callbackArgs=callbackArgs, + callbackKeywords=callbackKeywords, + errbackArgs=errbackArgs, + errbackKeywords=errbackKeywords, + ) + + def _wrap_callback(self, f): + def g(res, *args, **kwargs): + with PreserveLoggingContext(): + LoggingContext.thread_local.current_context = self._log_context + res = f(res, *args, **kwargs) + return res + return g + + def preserve_context_over_fn(fn, *args, **kwargs): """Takes a function and invokes it with the given arguments, but removes and restores the current logging context while doing so. @@ -160,24 +191,7 @@ def preserve_context_over_deferred(deferred): """Given a deferred wrap it such that any callbacks added later to it will be invoked with the current context. """ - d = defer.Deferred() - current_context = LoggingContext.current_context() - - def cb(res): - with PreserveLoggingContext(): - LoggingContext.thread_local.current_context = current_context - res = d.callback(res) - return res - - def eb(failure): - with PreserveLoggingContext(): - LoggingContext.thread_local.current_context = current_context - res = d.errback(failure) - return res - - if deferred.called: - return deferred - - deferred.addCallbacks(cb, eb) + d = _PreservingContextDeferred(current_context) + deferred.chainDeferred(d) return d |