summary refs log tree commit diff
diff options
context:
space:
mode:
authorSean Quah <seanq@element.io>2022-03-08 14:50:10 +0000
committerSean Quah <seanq@element.io>2022-03-08 17:14:56 +0000
commitc3ed4017542012c22521aeab8d7523f7fceb18dd (patch)
treef40f3c8f37b11fcea91bbcfc93ef1e0f0cedcdc2
parentAdd newsfile (diff)
downloadsynapse-c3ed4017542012c22521aeab8d7523f7fceb18dd.tar.xz
Add `delay_cancellation` utility function
`delay_cancellation` behaves like `stop_cancellation`, except it
delays `CancelledError`s until the original `Deferred` resolves.
This is handy for unifying cleanup paths and ensuring that uncancelled
coroutines don't use finished logcontexts.

Signed-off-by: Sean Quah <seanq@element.io>
-rw-r--r--synapse/util/async_helpers.py56
1 files changed, 56 insertions, 0 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index bb777e7613..3eb1a2d8e2 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -714,3 +714,59 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
     new_deferred: defer.Deferred[T] = defer.Deferred()
     deferred.chainDeferred(new_deferred)
     return new_deferred
+
+
+def delay_cancellation(deferred: "defer.Deferred[T]", all: bool) -> "defer.Deferred[T]":
+    """Delay cancellation of a `Deferred` until it resolves.
+
+    Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
+    resolve with a `CancelledError` until the original `Deferred` resolves.
+
+    Args:
+        deferred: The `Deferred` to protect against cancellation. Must not follow the
+            Synapse logcontext rules if `all` is `False`.
+        all: `True` to delay multiple cancellations. `False` to delay only the first
+            cancellation.
+
+    Returns:
+        A new `Deferred`, which will contain the result of the original `Deferred`.
+        The new `Deferred` will not propagate cancellation through to the original.
+        When cancelled, the new `Deferred` will wait until the original `Deferred`
+        resolves before failing with a `CancelledError`.
+
+        The new `Deferred` will only follow the Synapse logcontext rules if `all` is
+        `True` and `deferred` follows the Synapse logcontext rules. Otherwise the new
+        `Deferred` should be wrapped with `make_deferred_yieldable`.
+    """
+
+    def cancel_errback(failure: Failure) -> Union[Failure, "defer.Deferred[T]"]:
+        """Insert another `Deferred` into the chain to delay cancellation.
+
+        Called when the original `Deferred` resolves or the new `Deferred` is
+        cancelled.
+        """
+        failure.trap(CancelledError)
+
+        if deferred.called and not deferred.paused:
+            # The `CancelledError` came from the original `Deferred`. Pass it through.
+            return failure
+
+        # Construct another `Deferred` that will only fail with the `CancelledError`
+        # once the original `Deferred` resolves.
+        delay_deferred: "defer.Deferred[T]" = defer.Deferred()
+        deferred.chainDeferred(delay_deferred)
+
+        if all:
+            # Intercept cancellations recursively. Each cancellation will cause another
+            # `Deferred` to be inserted into the chain.
+            delay_deferred.addErrback(cancel_errback)
+
+        # Override the result with the `CancelledError`.
+        delay_deferred.addBoth(lambda _: failure)
+
+        return delay_deferred
+
+    new_deferred: "defer.Deferred[T]" = defer.Deferred()
+    deferred.chainDeferred(new_deferred)
+    new_deferred.addErrback(cancel_errback)
+    return new_deferred