summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/logcontext.py52
1 files changed, 33 insertions, 19 deletions
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index a92d518b43..7e6062c1b8 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -140,6 +140,37 @@ class PreserveLoggingContext(object):
                 )
 
 
+class _PreservingContextDeferred(defer.Deferred):
+    """A deferred that ensures that all callbacks and errbacks are called with
+    the given logging context.
+    """
+    def __init__(self, context):
+        self._log_context = context
+        defer.Deferred.__init__(self)
+
+    def addCallbacks(self, callback, errback=None,
+                     callbackArgs=None, callbackKeywords=None,
+                     errbackArgs=None, errbackKeywords=None):
+        callback = self._wrap_callback(callback)
+        errback = self._wrap_callback(errback)
+        return defer.Deferred.addCallbacks(
+            self, callback,
+            errback=errback,
+            callbackArgs=callbackArgs,
+            callbackKeywords=callbackKeywords,
+            errbackArgs=errbackArgs,
+            errbackKeywords=errbackKeywords,
+        )
+
+    def _wrap_callback(self, f):
+        def g(res, *args, **kwargs):
+            with PreserveLoggingContext():
+                LoggingContext.thread_local.current_context = self._log_context
+                res = f(res, *args, **kwargs)
+            return res
+        return g
+
+
 def preserve_context_over_fn(fn, *args, **kwargs):
     """Takes a function and invokes it with the given arguments, but removes
     and restores the current logging context while doing so.
@@ -160,24 +191,7 @@ def preserve_context_over_deferred(deferred):
     """Given a deferred wrap it such that any callbacks added later to it will
     be invoked with the current context.
     """
-    d = defer.Deferred()
-
     current_context = LoggingContext.current_context()
-
-    def cb(res):
-        with PreserveLoggingContext():
-            LoggingContext.thread_local.current_context = current_context
-            res = d.callback(res)
-        return res
-
-    def eb(failure):
-        with PreserveLoggingContext():
-            LoggingContext.thread_local.current_context = current_context
-            res = d.errback(failure)
-        return res
-
-    if deferred.called:
-        return deferred
-
-    deferred.addCallbacks(cb, eb)
+    d = _PreservingContextDeferred(current_context)
+    deferred.chainDeferred(d)
     return d