summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/util/test_async_helpers.py33
1 files changed, 31 insertions, 2 deletions
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index e5bc416de1..daacc54c72 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -382,7 +382,7 @@ class StopCancellationTests(TestCase):
 class DelayCancellationTests(TestCase):
     """Tests for the `delay_cancellation` function."""
 
-    def test_cancellation(self):
+    def test_deferred_cancellation(self):
         """Test that cancellation of the new `Deferred` waits for the original."""
         deferred: "Deferred[str]" = Deferred()
         wrapper_deferred = delay_cancellation(deferred)
@@ -403,6 +403,35 @@ class DelayCancellationTests(TestCase):
         # Now that the original `Deferred` has failed, we should get a `CancelledError`.
         self.failureResultOf(wrapper_deferred, CancelledError)
 
+    def test_coroutine_cancellation(self):
+        """Test that cancellation of the new `Deferred` waits for the original."""
+        blocking_deferred: "Deferred[None]" = Deferred()
+        completion_deferred: "Deferred[None]" = Deferred()
+
+        async def task():
+            await blocking_deferred
+            completion_deferred.callback(None)
+            # Raise an exception. Twisted should consume it, otherwise unwanted
+            # tracebacks will be printed in logs.
+            raise ValueError("abc")
+
+        wrapper_deferred = delay_cancellation(task())
+
+        # Cancel the new `Deferred`.
+        wrapper_deferred.cancel()
+        self.assertNoResult(wrapper_deferred)
+        self.assertFalse(
+            blocking_deferred.called, "Cancellation was propagated too deep"
+        )
+        self.assertFalse(completion_deferred.called)
+
+        # Unblock the task.
+        blocking_deferred.callback(None)
+        self.assertTrue(completion_deferred.called)
+
+        # Now that the original coroutine has failed, we should get a `CancelledError`.
+        self.failureResultOf(wrapper_deferred, CancelledError)
+
     def test_suppresses_second_cancellation(self):
         """Test that a second cancellation is suppressed.
 
@@ -451,7 +480,7 @@ class DelayCancellationTests(TestCase):
         async def outer():
             with LoggingContext("c") as c:
                 try:
-                    await delay_cancellation(defer.ensureDeferred(inner()))
+                    await delay_cancellation(inner())
                     self.fail("`CancelledError` was not raised")
                 except CancelledError:
                     self.assertEqual(c, current_context())