diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 5eb545c86e..df1e9c1b83 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -41,7 +41,6 @@ from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
-from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -794,7 +793,7 @@ class DatabasePool:
# We also wait until everything above is done before releasing the
# `CancelledError`, so that logging contexts won't get used after they have been
# finished.
- return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
+ return await delay_cancellation(_runInteraction())
async def runWithConnection(
self,
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 650e44de22..e27c5d298f 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -14,6 +14,7 @@
# limitations under the License.
import abc
+import asyncio
import collections
import inspect
import itertools
@@ -25,6 +26,7 @@ from typing import (
Awaitable,
Callable,
Collection,
+ Coroutine,
Dict,
Generic,
Hashable,
@@ -701,27 +703,57 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
return new_deferred
-def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
- """Delay cancellation of a `Deferred` until it resolves.
+@overload
+def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]":
+ ...
+
+
+@overload
+def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]":
+ ...
+
+
+@overload
+def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
+ ...
+
+
+def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
+ """Delay cancellation of a coroutine or `Deferred` awaitable until it resolves.
Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
- resolve with a `CancelledError` until the original `Deferred` resolves.
+ resolve with a `CancelledError` until the original awaitable resolves.
Args:
- deferred: The `Deferred` to protect against cancellation. May optionally follow
- the Synapse logcontext rules.
+ deferred: The coroutine or `Deferred` to protect against cancellation. May
+ optionally follow the Synapse logcontext rules.
Returns:
- A new `Deferred`, which will contain the result of the original `Deferred`.
- The new `Deferred` will not propagate cancellation through to the original.
- When cancelled, the new `Deferred` will wait until the original `Deferred`
- resolves before failing with a `CancelledError`.
+ A new `Deferred`, which will contain the result of the original coroutine or
+ `Deferred`. The new `Deferred` will not propagate cancellation through to the
+ original coroutine or `Deferred`.
- The new `Deferred` will follow the Synapse logcontext rules if `deferred`
+ When cancelled, the new `Deferred` will wait until the original coroutine or
+ `Deferred` resolves before failing with a `CancelledError`.
+
+ The new `Deferred` will follow the Synapse logcontext rules if `awaitable`
follows the Synapse logcontext rules. Otherwise the new `Deferred` should be
wrapped with `make_deferred_yieldable`.
"""
+ # First, convert the awaitable into a `Deferred`.
+ if isinstance(awaitable, defer.Deferred):
+ deferred = awaitable
+ elif asyncio.iscoroutine(awaitable):
+ # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
+ # type-checking, but we'd need Twisted >= 21.2.
+ deferred = defer.ensureDeferred(awaitable)
+ else:
+ # We have no idea what to do with this awaitable.
+ # We assume it's already resolved, such as `DoneAwaitable`s or `Future`s from
+ # `make_awaitable`, and let the caller `await` it normally.
+ return awaitable
+
def handle_cancel(new_deferred: "defer.Deferred[T]") -> None:
# before the new deferred is cancelled, we `pause` it to stop the cancellation
# propagating. we then `unpause` it once the wrapped deferred completes, to
|