diff options
Diffstat (limited to 'synapse/util')
-rw-r--r-- | synapse/util/async_helpers.py | 52 |
1 files changed, 42 insertions, 10 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 650e44de22..e27c5d298f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -14,6 +14,7 @@ # limitations under the License. import abc +import asyncio import collections import inspect import itertools @@ -25,6 +26,7 @@ from typing import ( Awaitable, Callable, Collection, + Coroutine, Dict, Generic, Hashable, @@ -701,27 +703,57 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": return new_deferred -def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": - """Delay cancellation of a `Deferred` until it resolves. +@overload +def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]": + ... + + +@overload +def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]": + ... + + +@overload +def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]: + ... + + +def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]: + """Delay cancellation of a coroutine or `Deferred` awaitable until it resolves. Has the same effect as `stop_cancellation`, but the returned `Deferred` will not - resolve with a `CancelledError` until the original `Deferred` resolves. + resolve with a `CancelledError` until the original awaitable resolves. Args: - deferred: The `Deferred` to protect against cancellation. May optionally follow - the Synapse logcontext rules. + deferred: The coroutine or `Deferred` to protect against cancellation. May + optionally follow the Synapse logcontext rules. Returns: - A new `Deferred`, which will contain the result of the original `Deferred`. - The new `Deferred` will not propagate cancellation through to the original. - When cancelled, the new `Deferred` will wait until the original `Deferred` - resolves before failing with a `CancelledError`. + A new `Deferred`, which will contain the result of the original coroutine or + `Deferred`. The new `Deferred` will not propagate cancellation through to the + original coroutine or `Deferred`. - The new `Deferred` will follow the Synapse logcontext rules if `deferred` + When cancelled, the new `Deferred` will wait until the original coroutine or + `Deferred` resolves before failing with a `CancelledError`. + + The new `Deferred` will follow the Synapse logcontext rules if `awaitable` follows the Synapse logcontext rules. Otherwise the new `Deferred` should be wrapped with `make_deferred_yieldable`. """ + # First, convert the awaitable into a `Deferred`. + if isinstance(awaitable, defer.Deferred): + deferred = awaitable + elif asyncio.iscoroutine(awaitable): + # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant + # type-checking, but we'd need Twisted >= 21.2. + deferred = defer.ensureDeferred(awaitable) + else: + # We have no idea what to do with this awaitable. + # We assume it's already resolved, such as `DoneAwaitable`s or `Future`s from + # `make_awaitable`, and let the caller `await` it normally. + return awaitable + def handle_cancel(new_deferred: "defer.Deferred[T]") -> None: # before the new deferred is cancelled, we `pause` it to stop the cancellation # propagating. we then `unpause` it once the wrapped deferred completes, to |