Add `stop_cancellation` utility function (#12106)
1 files changed, 45 insertions, 0 deletions
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index cce8d595fc..362014f4cb 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -27,6 +27,7 @@ from synapse.logging.context import (
from synapse.util.async_helpers import (
ObservableDeferred,
concurrently_execute,
+ stop_cancellation,
timeout_deferred,
)
@@ -282,3 +283,47 @@ class ConcurrentlyExecuteTest(TestCase):
d2 = ensureDeferred(caller())
d1.callback(0)
self.successResultOf(d2)
+
+
+class StopCancellationTests(TestCase):
+ """Tests for the `stop_cancellation` function."""
+
+ def test_succeed(self):
+ """Test that the new `Deferred` receives the result."""
+ deferred: "Deferred[str]" = Deferred()
+ wrapper_deferred = stop_cancellation(deferred)
+
+ # Success should propagate through.
+ deferred.callback("success")
+ self.assertTrue(wrapper_deferred.called)
+ self.assertEqual("success", self.successResultOf(wrapper_deferred))
+
+ def test_failure(self):
+ """Test that the new `Deferred` receives the `Failure`."""
+ deferred: "Deferred[str]" = Deferred()
+ wrapper_deferred = stop_cancellation(deferred)
+
+ # Failure should propagate through.
+ deferred.errback(ValueError("abc"))
+ self.assertTrue(wrapper_deferred.called)
+ self.failureResultOf(wrapper_deferred, ValueError)
+ self.assertIsNone(deferred.result, "`Failure` was not consumed")
+
+ def test_cancellation(self):
+ """Test that cancellation of the new `Deferred` leaves the original running."""
+ deferred: "Deferred[str]" = Deferred()
+ wrapper_deferred = stop_cancellation(deferred)
+
+ # Cancel the new `Deferred`.
+ wrapper_deferred.cancel()
+ self.assertTrue(wrapper_deferred.called)
+ self.failureResultOf(wrapper_deferred, CancelledError)
+ self.assertFalse(
+ deferred.called, "Original `Deferred` was unexpectedly cancelled."
+ )
+
+ # Now make the inner `Deferred` fail.
+ # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
+ # in logs.
+ deferred.errback(ValueError("abc"))
+ self.assertIsNone(deferred.result, "`Failure` was not consumed")
|