summary refs log tree commit diff
path: root/tests/util/test_logcontext.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/util/test_logcontext.py55
1 files changed, 35 insertions, 20 deletions
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 8adaee3c8d..8b8455c8b7 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -1,8 +1,14 @@
 import twisted.python.failure
 from twisted.internet import defer, reactor
 
-from synapse.util import Clock, logcontext
-from synapse.util.logcontext import LoggingContext
+from synapse.logging.context import (
+    LoggingContext,
+    PreserveLoggingContext,
+    make_deferred_yieldable,
+    nested_logging_context,
+    run_in_background,
+)
+from synapse.util import Clock
 
 from .. import unittest
 
@@ -39,24 +45,17 @@ class LoggingContextTestCase(unittest.TestCase):
 
         callback_completed = [False]
 
-        def test():
+        with LoggingContext() as context_one:
             context_one.request = "one"
-            d = function()
+
+            # fire off function, but don't wait on it.
+            d2 = run_in_background(function)
 
             def cb(res):
-                self._check_test_key("one")
                 callback_completed[0] = True
                 return res
 
-            d.addCallback(cb)
-
-            return d
-
-        with LoggingContext() as context_one:
-            context_one.request = "one"
-
-            # fire off function, but don't wait on it.
-            logcontext.run_in_background(test)
+            d2.addCallback(cb)
 
             self._check_test_key("one")
 
@@ -92,7 +91,7 @@ class LoggingContextTestCase(unittest.TestCase):
     def test_run_in_background_with_non_blocking_fn(self):
         @defer.inlineCallbacks
         def nonblocking_function():
-            with logcontext.PreserveLoggingContext():
+            with PreserveLoggingContext():
                 yield defer.succeed(None)
 
         return self._test_run_in_background(nonblocking_function)
@@ -101,7 +100,23 @@ class LoggingContextTestCase(unittest.TestCase):
         # a function which returns a deferred which looks like it has been
         # called, but is actually paused
         def testfunc():
-            return logcontext.make_deferred_yieldable(_chained_deferred_function())
+            return make_deferred_yieldable(_chained_deferred_function())
+
+        return self._test_run_in_background(testfunc)
+
+    def test_run_in_background_with_coroutine(self):
+        async def testfunc():
+            self._check_test_key("one")
+            d = Clock(reactor).sleep(0)
+            self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+            await d
+            self._check_test_key("one")
+
+        return self._test_run_in_background(testfunc)
+
+    def test_run_in_background_with_nonblocking_coroutine(self):
+        async def testfunc():
+            self._check_test_key("one")
 
         return self._test_run_in_background(testfunc)
 
@@ -119,7 +134,7 @@ class LoggingContextTestCase(unittest.TestCase):
         with LoggingContext() as context_one:
             context_one.request = "one"
 
-            d1 = logcontext.make_deferred_yieldable(blocking_function())
+            d1 = make_deferred_yieldable(blocking_function())
             # make sure that the context was reset by make_deferred_yieldable
             self.assertIs(LoggingContext.current_context(), sentinel_context)
 
@@ -135,7 +150,7 @@ class LoggingContextTestCase(unittest.TestCase):
         with LoggingContext() as context_one:
             context_one.request = "one"
 
-            d1 = logcontext.make_deferred_yieldable(_chained_deferred_function())
+            d1 = make_deferred_yieldable(_chained_deferred_function())
             # make sure that the context was reset by make_deferred_yieldable
             self.assertIs(LoggingContext.current_context(), sentinel_context)
 
@@ -152,7 +167,7 @@ class LoggingContextTestCase(unittest.TestCase):
         with LoggingContext() as context_one:
             context_one.request = "one"
 
-            d1 = logcontext.make_deferred_yieldable("bum")
+            d1 = make_deferred_yieldable("bum")
             self._check_test_key("one")
 
             r = yield d1
@@ -161,7 +176,7 @@ class LoggingContextTestCase(unittest.TestCase):
 
     def test_nested_logging_context(self):
         with LoggingContext(request="foo"):
-            nested_context = logcontext.nested_logging_context(suffix="bar")
+            nested_context = nested_logging_context(suffix="bar")
             self.assertEqual(nested_context.request, "foo-bar")