1 files changed, 61 insertions, 0 deletions
diff --git a/tests/util/test_log_context.py b/tests/util/test_log_context.py
index 65a330a0e9..9ffe209c4d 100644
--- a/tests/util/test_log_context.py
+++ b/tests/util/test_log_context.py
@@ -1,8 +1,10 @@
+import twisted.python.failure
from twisted.internet import defer
from twisted.internet import reactor
from .. import unittest
from synapse.util.async import sleep
+from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext
@@ -33,3 +35,62 @@ class LoggingContextTestCase(unittest.TestCase):
context_one.test_key = "one"
yield sleep(0)
self._check_test_key("one")
+
+ def _test_preserve_fn(self, function):
+ sentinel_context = LoggingContext.current_context()
+
+ callback_completed = [False]
+
+ @defer.inlineCallbacks
+ def cb():
+ context_one.test_key = "one"
+ yield function()
+ self._check_test_key("one")
+
+ callback_completed[0] = True
+
+ with LoggingContext() as context_one:
+ context_one.test_key = "one"
+
+ # fire off function, but don't wait on it.
+ logcontext.preserve_fn(cb)()
+
+ self._check_test_key("one")
+
+ # now wait for the function under test to have run, and check that
+ # the logcontext is left in a sane state.
+ d2 = defer.Deferred()
+
+ def check_logcontext():
+ if not callback_completed[0]:
+ reactor.callLater(0.01, check_logcontext)
+ return
+
+ # make sure that the context was reset before it got thrown back
+ # into the reactor
+ try:
+ self.assertIs(LoggingContext.current_context(),
+ sentinel_context)
+ d2.callback(None)
+ except BaseException:
+ d2.errback(twisted.python.failure.Failure())
+
+ reactor.callLater(0.01, check_logcontext)
+
+ # test is done once d2 finishes
+ return d2
+
+ def test_preserve_fn_with_blocking_fn(self):
+ @defer.inlineCallbacks
+ def blocking_function():
+ yield sleep(0)
+
+ return self._test_preserve_fn(blocking_function)
+
+ def test_preserve_fn_with_non_blocking_fn(self):
+ @defer.inlineCallbacks
+ def nonblocking_function():
+ with logcontext.PreserveLoggingContext():
+ yield defer.succeed(None)
+
+ return self._test_preserve_fn(nonblocking_function)
|