summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/async.py9
-rw-r--r--synapse/util/logcontext.py15
2 files changed, 16 insertions, 8 deletions
diff --git a/synapse/util/async.py b/synapse/util/async.py
index c84b23ff46..347fb1e380 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -146,10 +146,10 @@ def concurrently_execute(func, args, limit):
         except StopIteration:
             pass
 
-    return defer.gatherResults([
+    return preserve_context_over_deferred(defer.gatherResults([
         preserve_fn(_concurrently_execute_inner)()
         for _ in xrange(limit)
-    ], consumeErrors=True).addErrback(unwrapFirstError)
+    ], consumeErrors=True)).addErrback(unwrapFirstError)
 
 
 class Linearizer(object):
@@ -181,7 +181,8 @@ class Linearizer(object):
         self.key_to_defer[key] = new_defer
 
         if current_defer:
-            yield preserve_context_over_deferred(current_defer)
+            with PreserveLoggingContext():
+                yield current_defer
 
         @contextmanager
         def _ctx_manager():
@@ -264,7 +265,7 @@ class ReadWriteLock(object):
         curr_readers.clear()
         self.key_to_current_writer[key] = new_defer
 
-        yield defer.gatherResults(to_wait_on)
+        yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
 
         @contextmanager
         def _ctx_manager():
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 7a87045f87..6c83eb213d 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -297,12 +297,13 @@ def preserve_context_over_fn(fn, *args, **kwargs):
         return res
 
 
-def preserve_context_over_deferred(deferred):
+def preserve_context_over_deferred(deferred, context=None):
     """Given a deferred wrap it such that any callbacks added later to it will
     be invoked with the current context.
     """
-    current_context = LoggingContext.current_context()
-    d = _PreservingContextDeferred(current_context)
+    if context is None:
+        context = LoggingContext.current_context()
+    d = _PreservingContextDeferred(context)
     deferred.chainDeferred(d)
     return d
 
@@ -316,7 +317,13 @@ def preserve_fn(f):
 
     def g(*args, **kwargs):
         with PreserveLoggingContext(current):
-            return f(*args, **kwargs)
+            res = f(*args, **kwargs)
+            if isinstance(res, defer.Deferred):
+                return preserve_context_over_deferred(
+                    res, context=LoggingContext.sentinel
+                )
+            else:
+                return res
     return g