summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/util/logcontext.py60
-rw-r--r--tests/util/test_logcontext.py66
2 files changed, 93 insertions, 33 deletions
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 01ac71e53e..e086e12213 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -302,7 +302,7 @@ def preserve_fn(f):
 def run_in_background(f, *args, **kwargs):
     """Calls a function, ensuring that the current context is restored after
     return from the function, and that the sentinel context is set once the
-    deferred returned by the funtion completes.
+    deferred returned by the function completes.
 
     Useful for wrapping functions that return a deferred which you don't yield
     on (for instance because you want to pass it to deferred.gatherResults()).
@@ -320,24 +320,31 @@ def run_in_background(f, *args, **kwargs):
         # by synchronous exceptions, so let's turn them into Failures.
         return defer.fail()
 
-    if isinstance(res, defer.Deferred) and not res.called:
-        # The function will have reset the context before returning, so
-        # we need to restore it now.
-        LoggingContext.set_current_context(current)
-
-        # The original context will be restored when the deferred
-        # completes, but there is nothing waiting for it, so it will
-        # get leaked into the reactor or some other function which
-        # wasn't expecting it. We therefore need to reset the context
-        # here.
-        #
-        # (If this feels asymmetric, consider it this way: we are
-        # effectively forking a new thread of execution. We are
-        # probably currently within a ``with LoggingContext()`` block,
-        # which is supposed to have a single entry and exit point. But
-        # by spawning off another deferred, we are effectively
-        # adding a new exit point.)
-        res.addBoth(_set_context_cb, LoggingContext.sentinel)
+    if not isinstance(res, defer.Deferred):
+        return res
+
+    if res.called and not res.paused:
+        # The function should have maintained the logcontext, so we can
+        # optimise out the messing about
+        return res
+
+    # The function may have reset the context before returning, so
+    # we need to restore it now.
+    ctx = LoggingContext.set_current_context(current)
+
+    # The original context will be restored when the deferred
+    # completes, but there is nothing waiting for it, so it will
+    # get leaked into the reactor or some other function which
+    # wasn't expecting it. We therefore need to reset the context
+    # here.
+    #
+    # (If this feels asymmetric, consider it this way: we are
+    # effectively forking a new thread of execution. We are
+    # probably currently within a ``with LoggingContext()`` block,
+    # which is supposed to have a single entry and exit point. But
+    # by spawning off another deferred, we are effectively
+    # adding a new exit point.)
+    res.addBoth(_set_context_cb, ctx)
     return res
 
 
@@ -354,9 +361,18 @@ def make_deferred_yieldable(deferred):
 
     (This is more-or-less the opposite operation to run_in_background.)
     """
-    if isinstance(deferred, defer.Deferred) and not deferred.called:
-        prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
-        deferred.addBoth(_set_context_cb, prev_context)
+    if not isinstance(deferred, defer.Deferred):
+        return deferred
+
+    if deferred.called and not deferred.paused:
+        # it looks like this deferred is ready to run any callbacks we give it
+        # immediately. We may as well optimise out the logcontext faffery.
+        return deferred
+
+    # ok, we can't be sure that a yield won't block, so let's reset the
+    # logcontext, and add a callback to the deferred to restore it.
+    prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+    deferred.addBoth(_set_context_cb, prev_context)
     return deferred
 
 
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 4850722bc5..ad78d884e0 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -36,24 +36,28 @@ class LoggingContextTestCase(unittest.TestCase):
             yield sleep(0)
             self._check_test_key("one")
 
-    def _test_preserve_fn(self, function):
+    def _test_run_in_background(self, function):
         sentinel_context = LoggingContext.current_context()
 
         callback_completed = [False]
 
-        @defer.inlineCallbacks
-        def cb():
+        def test():
             context_one.request = "one"
-            yield function()
-            self._check_test_key("one")
+            d = function()
 
-            callback_completed[0] = True
+            def cb(res):
+                self._check_test_key("one")
+                callback_completed[0] = True
+                return res
+            d.addCallback(cb)
+
+            return d
 
         with LoggingContext() as context_one:
             context_one.request = "one"
 
             # fire off function, but don't wait on it.
-            logcontext.preserve_fn(cb)()
+            logcontext.run_in_background(test)
 
             self._check_test_key("one")
 
@@ -80,20 +84,30 @@ class LoggingContextTestCase(unittest.TestCase):
         # test is done once d2 finishes
         return d2
 
-    def test_preserve_fn_with_blocking_fn(self):
+    def test_run_in_background_with_blocking_fn(self):
         @defer.inlineCallbacks
         def blocking_function():
             yield sleep(0)
 
-        return self._test_preserve_fn(blocking_function)
+        return self._test_run_in_background(blocking_function)
 
-    def test_preserve_fn_with_non_blocking_fn(self):
+    def test_run_in_background_with_non_blocking_fn(self):
         @defer.inlineCallbacks
         def nonblocking_function():
             with logcontext.PreserveLoggingContext():
                 yield defer.succeed(None)
 
-        return self._test_preserve_fn(nonblocking_function)
+        return self._test_run_in_background(nonblocking_function)
+
+    def test_run_in_background_with_chained_deferred(self):
+        # a function which returns a deferred which looks like it has been
+        # called, but is actually paused
+        def testfunc():
+            return logcontext.make_deferred_yieldable(
+                _chained_deferred_function()
+            )
+
+        return self._test_run_in_background(testfunc)
 
     @defer.inlineCallbacks
     def test_make_deferred_yieldable(self):
@@ -119,6 +133,22 @@ class LoggingContextTestCase(unittest.TestCase):
             self._check_test_key("one")
 
     @defer.inlineCallbacks
+    def test_make_deferred_yieldable_with_chained_deferreds(self):
+        sentinel_context = LoggingContext.current_context()
+
+        with LoggingContext() as context_one:
+            context_one.request = "one"
+
+            d1 = logcontext.make_deferred_yieldable(_chained_deferred_function())
+            # make sure that the context was reset by make_deferred_yieldable
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+            yield d1
+
+            # now it should be restored
+            self._check_test_key("one")
+
+    @defer.inlineCallbacks
     def test_make_deferred_yieldable_on_non_deferred(self):
         """Check that make_deferred_yieldable does the right thing when its
         argument isn't actually a deferred"""
@@ -132,3 +162,17 @@ class LoggingContextTestCase(unittest.TestCase):
             r = yield d1
             self.assertEqual(r, "bum")
             self._check_test_key("one")
+
+
+# a function which returns a deferred which has been "called", but
+# which had a function which returned another incomplete deferred on
+# its callback list, so won't yet call any other new callbacks.
+def _chained_deferred_function():
+    d = defer.succeed(None)
+
+    def cb(res):
+        d2 = defer.Deferred()
+        reactor.callLater(0, d2.callback, res)
+        return d2
+    d.addCallback(cb)
+    return d