Update `delay_cancellation` to accept any awaitable (#12468)
This will mainly be useful when dealing with module callbacks, which are
all typed as returning `Awaitable`s instead of coroutines or
`Deferred`s.
Signed-off-by: Sean Quah <seanq@element.io>
1 files changed, 31 insertions, 2 deletions
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index e5bc416de1..daacc54c72 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -382,7 +382,7 @@ class StopCancellationTests(TestCase):
class DelayCancellationTests(TestCase):
"""Tests for the `delay_cancellation` function."""
- def test_cancellation(self):
+ def test_deferred_cancellation(self):
"""Test that cancellation of the new `Deferred` waits for the original."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred)
@@ -403,6 +403,35 @@ class DelayCancellationTests(TestCase):
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
+ def test_coroutine_cancellation(self):
+ """Test that cancellation of the new `Deferred` waits for the original."""
+ blocking_deferred: "Deferred[None]" = Deferred()
+ completion_deferred: "Deferred[None]" = Deferred()
+
+ async def task():
+ await blocking_deferred
+ completion_deferred.callback(None)
+ # Raise an exception. Twisted should consume it, otherwise unwanted
+ # tracebacks will be printed in logs.
+ raise ValueError("abc")
+
+ wrapper_deferred = delay_cancellation(task())
+
+ # Cancel the new `Deferred`.
+ wrapper_deferred.cancel()
+ self.assertNoResult(wrapper_deferred)
+ self.assertFalse(
+ blocking_deferred.called, "Cancellation was propagated too deep"
+ )
+ self.assertFalse(completion_deferred.called)
+
+ # Unblock the task.
+ blocking_deferred.callback(None)
+ self.assertTrue(completion_deferred.called)
+
+ # Now that the original coroutine has failed, we should get a `CancelledError`.
+ self.failureResultOf(wrapper_deferred, CancelledError)
+
def test_suppresses_second_cancellation(self):
"""Test that a second cancellation is suppressed.
@@ -451,7 +480,7 @@ class DelayCancellationTests(TestCase):
async def outer():
with LoggingContext("c") as c:
try:
- await delay_cancellation(defer.ensureDeferred(inner()))
+ await delay_cancellation(inner())
self.fail("`CancelledError` was not raised")
except CancelledError:
self.assertEqual(c, current_context())
|