diff --git a/changelog.d/12106.misc b/changelog.d/12106.misc
new file mode 100644
index 0000000000..d918e9e3b1
--- /dev/null
+++ b/changelog.d/12106.misc
@@ -0,0 +1 @@
+Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled.
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index a83296a229..81320b8972 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -665,3 +665,22 @@ def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
return value
return DoneAwaitable(value)
+
+
+def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
+ """Prevent a `Deferred` from being cancelled by wrapping it in another `Deferred`.
+
+ Args:
+ deferred: The `Deferred` to protect against cancellation. Must not follow the
+ Synapse logcontext rules.
+
+ Returns:
+ A new `Deferred`, which will contain the result of the original `Deferred`,
+ but will not propagate cancellation through to the original. When cancelled,
+ the new `Deferred` will fail with a `CancelledError` and will not follow the
+ Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap
+ the new `Deferred`.
+ """
+ new_deferred: defer.Deferred[T] = defer.Deferred()
+ deferred.chainDeferred(new_deferred)
+ return new_deferred
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index cce8d595fc..362014f4cb 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -27,6 +27,7 @@ from synapse.logging.context import (
from synapse.util.async_helpers import (
ObservableDeferred,
concurrently_execute,
+ stop_cancellation,
timeout_deferred,
)
@@ -282,3 +283,47 @@ class ConcurrentlyExecuteTest(TestCase):
d2 = ensureDeferred(caller())
d1.callback(0)
self.successResultOf(d2)
+
+
+class StopCancellationTests(TestCase):
+ """Tests for the `stop_cancellation` function."""
+
+ def test_succeed(self):
+ """Test that the new `Deferred` receives the result."""
+ deferred: "Deferred[str]" = Deferred()
+ wrapper_deferred = stop_cancellation(deferred)
+
+ # Success should propagate through.
+ deferred.callback("success")
+ self.assertTrue(wrapper_deferred.called)
+ self.assertEqual("success", self.successResultOf(wrapper_deferred))
+
+ def test_failure(self):
+ """Test that the new `Deferred` receives the `Failure`."""
+ deferred: "Deferred[str]" = Deferred()
+ wrapper_deferred = stop_cancellation(deferred)
+
+ # Failure should propagate through.
+ deferred.errback(ValueError("abc"))
+ self.assertTrue(wrapper_deferred.called)
+ self.failureResultOf(wrapper_deferred, ValueError)
+ self.assertIsNone(deferred.result, "`Failure` was not consumed")
+
+ def test_cancellation(self):
+ """Test that cancellation of the new `Deferred` leaves the original running."""
+ deferred: "Deferred[str]" = Deferred()
+ wrapper_deferred = stop_cancellation(deferred)
+
+ # Cancel the new `Deferred`.
+ wrapper_deferred.cancel()
+ self.assertTrue(wrapper_deferred.called)
+ self.failureResultOf(wrapper_deferred, CancelledError)
+ self.assertFalse(
+ deferred.called, "Original `Deferred` was unexpectedly cancelled."
+ )
+
+ # Now make the inner `Deferred` fail.
+ # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
+ # in logs.
+ deferred.errback(ValueError("abc"))
+ self.assertIsNone(deferred.result, "`Failure` was not consumed")
|