diff options
author | Sean Quah <8349537+squahtx@users.noreply.github.com> | 2022-04-22 18:20:06 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-22 18:20:06 +0100 |
commit | a50fb411b366576c3d4796b04c45385bdb7a6897 (patch) | |
tree | 1c41b7d0c634674f6a47e4edb8fc364c6c76b6fd /synapse | |
parent | MSC3202: Fix device_unused_fallback_keys -> device_unused_fallback_key_types ... (diff) | |
download | synapse-a50fb411b366576c3d4796b04c45385bdb7a6897.tar.xz |
Update `delay_cancellation` to accept any awaitable (#12468)
This will mainly be useful when dealing with module callbacks, which are all typed as returning `Awaitable`s instead of coroutines or `Deferred`s. Signed-off-by: Sean Quah <seanq@element.io>
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/storage/database.py | 3 | ||||
-rw-r--r-- | synapse/util/async_helpers.py | 52 |
2 files changed, 43 insertions, 12 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 5eb545c86e..df1e9c1b83 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -41,7 +41,6 @@ from prometheus_client import Histogram from typing_extensions import Literal from twisted.enterprise import adbapi -from twisted.internet import defer from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig @@ -794,7 +793,7 @@ class DatabasePool: # We also wait until everything above is done before releasing the # `CancelledError`, so that logging contexts won't get used after they have been # finished. - return await delay_cancellation(defer.ensureDeferred(_runInteraction())) + return await delay_cancellation(_runInteraction()) async def runWithConnection( self, diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 650e44de22..e27c5d298f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -14,6 +14,7 @@ # limitations under the License. import abc +import asyncio import collections import inspect import itertools @@ -25,6 +26,7 @@ from typing import ( Awaitable, Callable, Collection, + Coroutine, Dict, Generic, Hashable, @@ -701,27 +703,57 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": return new_deferred -def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": - """Delay cancellation of a `Deferred` until it resolves. +@overload +def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]": + ... + + +@overload +def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]": + ... + + +@overload +def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]: + ... + + +def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]: + """Delay cancellation of a coroutine or `Deferred` awaitable until it resolves. Has the same effect as `stop_cancellation`, but the returned `Deferred` will not - resolve with a `CancelledError` until the original `Deferred` resolves. + resolve with a `CancelledError` until the original awaitable resolves. Args: - deferred: The `Deferred` to protect against cancellation. May optionally follow - the Synapse logcontext rules. + deferred: The coroutine or `Deferred` to protect against cancellation. May + optionally follow the Synapse logcontext rules. Returns: - A new `Deferred`, which will contain the result of the original `Deferred`. - The new `Deferred` will not propagate cancellation through to the original. - When cancelled, the new `Deferred` will wait until the original `Deferred` - resolves before failing with a `CancelledError`. + A new `Deferred`, which will contain the result of the original coroutine or + `Deferred`. The new `Deferred` will not propagate cancellation through to the + original coroutine or `Deferred`. - The new `Deferred` will follow the Synapse logcontext rules if `deferred` + When cancelled, the new `Deferred` will wait until the original coroutine or + `Deferred` resolves before failing with a `CancelledError`. + + The new `Deferred` will follow the Synapse logcontext rules if `awaitable` follows the Synapse logcontext rules. Otherwise the new `Deferred` should be wrapped with `make_deferred_yieldable`. """ + # First, convert the awaitable into a `Deferred`. + if isinstance(awaitable, defer.Deferred): + deferred = awaitable + elif asyncio.iscoroutine(awaitable): + # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant + # type-checking, but we'd need Twisted >= 21.2. + deferred = defer.ensureDeferred(awaitable) + else: + # We have no idea what to do with this awaitable. + # We assume it's already resolved, such as `DoneAwaitable`s or `Future`s from + # `make_awaitable`, and let the caller `await` it normally. + return awaitable + def handle_cancel(new_deferred: "defer.Deferred[T]") -> None: # before the new deferred is cancelled, we `pause` it to stop the cancellation # propagating. we then `unpause` it once the wrapped deferred completes, to |