diff options
Diffstat (limited to 'tests/util')
-rw-r--r-- | tests/util/test_async_helpers.py | 33 |
1 files changed, 31 insertions, 2 deletions
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index e5bc416de1..daacc54c72 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -382,7 +382,7 @@ class StopCancellationTests(TestCase): class DelayCancellationTests(TestCase): """Tests for the `delay_cancellation` function.""" - def test_cancellation(self): + def test_deferred_cancellation(self): """Test that cancellation of the new `Deferred` waits for the original.""" deferred: "Deferred[str]" = Deferred() wrapper_deferred = delay_cancellation(deferred) @@ -403,6 +403,35 @@ class DelayCancellationTests(TestCase): # Now that the original `Deferred` has failed, we should get a `CancelledError`. self.failureResultOf(wrapper_deferred, CancelledError) + def test_coroutine_cancellation(self): + """Test that cancellation of the new `Deferred` waits for the original.""" + blocking_deferred: "Deferred[None]" = Deferred() + completion_deferred: "Deferred[None]" = Deferred() + + async def task(): + await blocking_deferred + completion_deferred.callback(None) + # Raise an exception. Twisted should consume it, otherwise unwanted + # tracebacks will be printed in logs. + raise ValueError("abc") + + wrapper_deferred = delay_cancellation(task()) + + # Cancel the new `Deferred`. + wrapper_deferred.cancel() + self.assertNoResult(wrapper_deferred) + self.assertFalse( + blocking_deferred.called, "Cancellation was propagated too deep" + ) + self.assertFalse(completion_deferred.called) + + # Unblock the task. + blocking_deferred.callback(None) + self.assertTrue(completion_deferred.called) + + # Now that the original coroutine has failed, we should get a `CancelledError`. + self.failureResultOf(wrapper_deferred, CancelledError) + def test_suppresses_second_cancellation(self): """Test that a second cancellation is suppressed. @@ -451,7 +480,7 @@ class DelayCancellationTests(TestCase): async def outer(): with LoggingContext("c") as c: try: - await delay_cancellation(defer.ensureDeferred(inner())) + await delay_cancellation(inner()) self.fail("`CancelledError` was not raised") except CancelledError: self.assertEqual(c, current_context()) |