summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12106.misc1
-rw-r--r--synapse/util/async_helpers.py19
-rw-r--r--tests/util/test_async_helpers.py45
3 files changed, 65 insertions, 0 deletions
diff --git a/changelog.d/12106.misc b/changelog.d/12106.misc
new file mode 100644
index 0000000000..d918e9e3b1
--- /dev/null
+++ b/changelog.d/12106.misc
@@ -0,0 +1 @@
+Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled.
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index a83296a229..81320b8972 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -665,3 +665,22 @@ def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
         return value
 
     return DoneAwaitable(value)
+
+
+def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
+    """Prevent a `Deferred` from being cancelled by wrapping it in another `Deferred`.
+
+    Args:
+        deferred: The `Deferred` to protect against cancellation. Must not follow the
+            Synapse logcontext rules.
+
+    Returns:
+        A new `Deferred`, which will contain the result of the original `Deferred`,
+        but will not propagate cancellation through to the original. When cancelled,
+        the new `Deferred` will fail with a `CancelledError` and will not follow the
+        Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap
+        the new `Deferred`.
+    """
+    new_deferred: defer.Deferred[T] = defer.Deferred()
+    deferred.chainDeferred(new_deferred)
+    return new_deferred
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index cce8d595fc..362014f4cb 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -27,6 +27,7 @@ from synapse.logging.context import (
 from synapse.util.async_helpers import (
     ObservableDeferred,
     concurrently_execute,
+    stop_cancellation,
     timeout_deferred,
 )
 
@@ -282,3 +283,47 @@ class ConcurrentlyExecuteTest(TestCase):
         d2 = ensureDeferred(caller())
         d1.callback(0)
         self.successResultOf(d2)
+
+
+class StopCancellationTests(TestCase):
+    """Tests for the `stop_cancellation` function."""
+
+    def test_succeed(self):
+        """Test that the new `Deferred` receives the result."""
+        deferred: "Deferred[str]" = Deferred()
+        wrapper_deferred = stop_cancellation(deferred)
+
+        # Success should propagate through.
+        deferred.callback("success")
+        self.assertTrue(wrapper_deferred.called)
+        self.assertEqual("success", self.successResultOf(wrapper_deferred))
+
+    def test_failure(self):
+        """Test that the new `Deferred` receives the `Failure`."""
+        deferred: "Deferred[str]" = Deferred()
+        wrapper_deferred = stop_cancellation(deferred)
+
+        # Failure should propagate through.
+        deferred.errback(ValueError("abc"))
+        self.assertTrue(wrapper_deferred.called)
+        self.failureResultOf(wrapper_deferred, ValueError)
+        self.assertIsNone(deferred.result, "`Failure` was not consumed")
+
+    def test_cancellation(self):
+        """Test that cancellation of the new `Deferred` leaves the original running."""
+        deferred: "Deferred[str]" = Deferred()
+        wrapper_deferred = stop_cancellation(deferred)
+
+        # Cancel the new `Deferred`.
+        wrapper_deferred.cancel()
+        self.assertTrue(wrapper_deferred.called)
+        self.failureResultOf(wrapper_deferred, CancelledError)
+        self.assertFalse(
+            deferred.called, "Original `Deferred` was unexpectedly cancelled."
+        )
+
+        # Now make the inner `Deferred` fail.
+        # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
+        # in logs.
+        deferred.errback(ValueError("abc"))
+        self.assertIsNone(deferred.result, "`Failure` was not consumed")