summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/federation.py5
-rw-r--r--synapse/util/logcontext.py61
-rw-r--r--tests/util/test_log_context.py61
3 files changed, 93 insertions, 34 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0cd5501b05..6157204924 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -933,8 +933,9 @@ class FederationHandler(BaseHandler):
             # lots of requests for missing prev_events which we do actually
             # have. Hence we fire off the deferred, but don't wait for it.
 
-            synapse.util.logcontext.reset_context_after_deferred(
-                self._handle_queued_pdus(room_queue))
+            synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)(
+                room_queue
+            )
 
         defer.returnValue(True)
 
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index d73670f9f2..7cbe390b15 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -308,47 +308,44 @@ def preserve_context_over_deferred(deferred, context=None):
     return d
 
 
-def reset_context_after_deferred(deferred):
-    """If the deferred is incomplete, add a callback which will reset the
-    context.
-
-    This is useful when you want to fire off a deferred, but don't want to
-    wait for it to complete. (The deferred will restore the current log context
-    when it completes, so if you don't do anything, it will leak log context.)
-
-    (If this feels asymmetric, consider it this way: we are effectively forking
-    a new thread of execution. We are probably currently within a
-    ``with LoggingContext()`` block, which is supposed to have a single entry
-    and exit point. But by spawning off another deferred, we are effectively
-    adding a new exit point.)
+def preserve_fn(f):
+    """Wraps a function, to ensure that the current context is restored after
+    return from the function, and that the sentinel context is set once the
+    deferred returned by the funtion completes.
 
-    Args:
-        deferred (defer.Deferred): deferred
+    Useful for wrapping functions that return a deferred which you don't yield
+    on.
     """
     def reset_context(result):
         LoggingContext.set_current_context(LoggingContext.sentinel)
         return result
 
-    if not deferred.called:
-        deferred.addBoth(reset_context)
-
-
-def preserve_fn(f):
-    """Ensures that function is called with correct context and that context is
-    restored after return. Useful for wrapping functions that return a deferred
-    which you don't yield on.
-    """
+    # XXX: why is this here rather than inside g? surely we want to preserve
+    # the context from the time the function was called, not when it was
+    # wrapped?
     current = LoggingContext.current_context()
 
     def g(*args, **kwargs):
-        with PreserveLoggingContext(current):
-            res = f(*args, **kwargs)
-            if isinstance(res, defer.Deferred):
-                return preserve_context_over_deferred(
-                    res, context=LoggingContext.sentinel
-                )
-            else:
-                return res
+        res = f(*args, **kwargs)
+        if isinstance(res, defer.Deferred) and not res.called:
+            # The function will have reset the context before returning, so
+            # we need to restore it now.
+            LoggingContext.set_current_context(current)
+
+            # The original context will be restored when the deferred
+            # completes, but there is nothing waiting for it, so it will
+            # get leaked into the reactor or some other function which
+            # wasn't expecting it. We therefore need to reset the context
+            # here.
+            #
+            # (If this feels asymmetric, consider it this way: we are
+            # effectively forking a new thread of execution. We are
+            # probably currently within a ``with LoggingContext()`` block,
+            # which is supposed to have a single entry and exit point. But
+            # by spawning off another deferred, we are effectively
+            # adding a new exit point.)
+            res.addBoth(reset_context)
+        return res
     return g
 
 
diff --git a/tests/util/test_log_context.py b/tests/util/test_log_context.py
index 65a330a0e9..9ffe209c4d 100644
--- a/tests/util/test_log_context.py
+++ b/tests/util/test_log_context.py
@@ -1,8 +1,10 @@
+import twisted.python.failure
 from twisted.internet import defer
 from twisted.internet import reactor
 from .. import unittest
 
 from synapse.util.async import sleep
+from synapse.util import logcontext
 from synapse.util.logcontext import LoggingContext
 
 
@@ -33,3 +35,62 @@ class LoggingContextTestCase(unittest.TestCase):
             context_one.test_key = "one"
             yield sleep(0)
             self._check_test_key("one")
+
+    def _test_preserve_fn(self, function):
+        sentinel_context = LoggingContext.current_context()
+
+        callback_completed = [False]
+
+        @defer.inlineCallbacks
+        def cb():
+            context_one.test_key = "one"
+            yield function()
+            self._check_test_key("one")
+
+            callback_completed[0] = True
+
+        with LoggingContext() as context_one:
+            context_one.test_key = "one"
+
+            # fire off function, but don't wait on it.
+            logcontext.preserve_fn(cb)()
+
+            self._check_test_key("one")
+
+        # now wait for the function under test to have run, and check that
+        # the logcontext is left in a sane state.
+        d2 = defer.Deferred()
+
+        def check_logcontext():
+            if not callback_completed[0]:
+                reactor.callLater(0.01, check_logcontext)
+                return
+
+            # make sure that the context was reset before it got thrown back
+            # into the reactor
+            try:
+                self.assertIs(LoggingContext.current_context(),
+                              sentinel_context)
+                d2.callback(None)
+            except BaseException:
+                d2.errback(twisted.python.failure.Failure())
+
+        reactor.callLater(0.01, check_logcontext)
+
+        # test is done once d2 finishes
+        return d2
+
+    def test_preserve_fn_with_blocking_fn(self):
+        @defer.inlineCallbacks
+        def blocking_function():
+            yield sleep(0)
+
+        return self._test_preserve_fn(blocking_function)
+
+    def test_preserve_fn_with_non_blocking_fn(self):
+        @defer.inlineCallbacks
+        def nonblocking_function():
+            with logcontext.PreserveLoggingContext():
+                yield defer.succeed(None)
+
+        return self._test_preserve_fn(nonblocking_function)