diff --git a/changelog.d/12468.misc b/changelog.d/12468.misc
new file mode 100644
index 0000000000..3d5d25247f
--- /dev/null
+++ b/changelog.d/12468.misc
@@ -0,0 +1 @@
+Update `delay_cancellation` to accept any awaitable, rather than just `Deferred`s.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 5eb545c86e..df1e9c1b83 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -41,7 +41,6 @@ from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
-from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -794,7 +793,7 @@ class DatabasePool:
# We also wait until everything above is done before releasing the
# `CancelledError`, so that logging contexts won't get used after they have been
# finished.
- return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
+ return await delay_cancellation(_runInteraction())
async def runWithConnection(
self,
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 650e44de22..e27c5d298f 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -14,6 +14,7 @@
# limitations under the License.
import abc
+import asyncio
import collections
import inspect
import itertools
@@ -25,6 +26,7 @@ from typing import (
Awaitable,
Callable,
Collection,
+ Coroutine,
Dict,
Generic,
Hashable,
@@ -701,27 +703,57 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
return new_deferred
-def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
- """Delay cancellation of a `Deferred` until it resolves.
+@overload
+def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]":
+ ...
+
+
+@overload
+def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]":
+ ...
+
+
+@overload
+def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
+ ...
+
+
+def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
+ """Delay cancellation of a coroutine or `Deferred` awaitable until it resolves.
Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
- resolve with a `CancelledError` until the original `Deferred` resolves.
+ resolve with a `CancelledError` until the original awaitable resolves.
Args:
- deferred: The `Deferred` to protect against cancellation. May optionally follow
- the Synapse logcontext rules.
+ deferred: The coroutine or `Deferred` to protect against cancellation. May
+ optionally follow the Synapse logcontext rules.
Returns:
- A new `Deferred`, which will contain the result of the original `Deferred`.
- The new `Deferred` will not propagate cancellation through to the original.
- When cancelled, the new `Deferred` will wait until the original `Deferred`
- resolves before failing with a `CancelledError`.
+ A new `Deferred`, which will contain the result of the original coroutine or
+ `Deferred`. The new `Deferred` will not propagate cancellation through to the
+ original coroutine or `Deferred`.
- The new `Deferred` will follow the Synapse logcontext rules if `deferred`
+ When cancelled, the new `Deferred` will wait until the original coroutine or
+ `Deferred` resolves before failing with a `CancelledError`.
+
+ The new `Deferred` will follow the Synapse logcontext rules if `awaitable`
follows the Synapse logcontext rules. Otherwise the new `Deferred` should be
wrapped with `make_deferred_yieldable`.
"""
+ # First, convert the awaitable into a `Deferred`.
+ if isinstance(awaitable, defer.Deferred):
+ deferred = awaitable
+ elif asyncio.iscoroutine(awaitable):
+ # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
+ # type-checking, but we'd need Twisted >= 21.2.
+ deferred = defer.ensureDeferred(awaitable)
+ else:
+ # We have no idea what to do with this awaitable.
+ # We assume it's already resolved, such as `DoneAwaitable`s or `Future`s from
+ # `make_awaitable`, and let the caller `await` it normally.
+ return awaitable
+
def handle_cancel(new_deferred: "defer.Deferred[T]") -> None:
# before the new deferred is cancelled, we `pause` it to stop the cancellation
# propagating. we then `unpause` it once the wrapped deferred completes, to
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index e5bc416de1..daacc54c72 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -382,7 +382,7 @@ class StopCancellationTests(TestCase):
class DelayCancellationTests(TestCase):
"""Tests for the `delay_cancellation` function."""
- def test_cancellation(self):
+ def test_deferred_cancellation(self):
"""Test that cancellation of the new `Deferred` waits for the original."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred)
@@ -403,6 +403,35 @@ class DelayCancellationTests(TestCase):
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
+ def test_coroutine_cancellation(self):
+ """Test that cancellation of the new `Deferred` waits for the original."""
+ blocking_deferred: "Deferred[None]" = Deferred()
+ completion_deferred: "Deferred[None]" = Deferred()
+
+ async def task():
+ await blocking_deferred
+ completion_deferred.callback(None)
+ # Raise an exception. Twisted should consume it, otherwise unwanted
+ # tracebacks will be printed in logs.
+ raise ValueError("abc")
+
+ wrapper_deferred = delay_cancellation(task())
+
+ # Cancel the new `Deferred`.
+ wrapper_deferred.cancel()
+ self.assertNoResult(wrapper_deferred)
+ self.assertFalse(
+ blocking_deferred.called, "Cancellation was propagated too deep"
+ )
+ self.assertFalse(completion_deferred.called)
+
+ # Unblock the task.
+ blocking_deferred.callback(None)
+ self.assertTrue(completion_deferred.called)
+
+ # Now that the original coroutine has failed, we should get a `CancelledError`.
+ self.failureResultOf(wrapper_deferred, CancelledError)
+
def test_suppresses_second_cancellation(self):
"""Test that a second cancellation is suppressed.
@@ -451,7 +480,7 @@ class DelayCancellationTests(TestCase):
async def outer():
with LoggingContext("c") as c:
try:
- await delay_cancellation(defer.ensureDeferred(inner()))
+ await delay_cancellation(inner())
self.fail("`CancelledError` was not raised")
except CancelledError:
self.assertEqual(c, current_context())
|