diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 731b0c4e59..dff5a5d262 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -27,6 +27,7 @@ from typing import (
Callable,
ContextManager,
Dict,
+ Generator,
List,
Optional,
Set,
@@ -49,7 +50,10 @@ from synapse.http.server import (
respond_with_json,
)
from synapse.http.site import SynapseRequest
-from synapse.logging.context import LoggingContext, make_deferred_yieldable
+from synapse.logging.context import (
+ LoggingContext,
+ make_deferred_yieldable,
+)
from synapse.types import JsonDict
from tests.server import FakeChannel, make_request
@@ -199,7 +203,7 @@ def make_request_with_cancellation_test(
#
# We would like to trigger a cancellation at the first `await`, re-run the
# request and cancel at the second `await`, and so on. By patching
- # `Deferred.__next__`, we can intercept `await`s, track which ones we have or
+ # `Deferred.__await__`, we can intercept `await`s, track which ones we have or
# have not seen, and force them to block when they wouldn't have.
# The set of previously seen `await`s.
@@ -211,7 +215,7 @@ def make_request_with_cancellation_test(
)
for request_number in itertools.count(1):
- deferred_patch = Deferred__next__Patch(seen_awaits, request_number)
+ deferred_patch = Deferred__await__Patch(seen_awaits, request_number)
try:
with mock.patch(
@@ -250,6 +254,8 @@ def make_request_with_cancellation_test(
)
if respond_mock.called:
+ _log_for_request(request_number, "--- response finished ---")
+
# The request ran to completion and we are done with testing it.
# `respond_with_json` writes the response asynchronously, so we
@@ -311,8 +317,8 @@ def make_request_with_cancellation_test(
assert False, "unreachable" # noqa: B011
-class Deferred__next__Patch:
- """A `Deferred.__next__` patch that will intercept `await`s and force them
+class Deferred__await__Patch:
+ """A `Deferred.__await__` patch that will intercept `await`s and force them
to block once it sees a new `await`.
When done with the patch, `unblock_awaits()` must be called to clean up after any
@@ -322,7 +328,7 @@ class Deferred__next__Patch:
Usage:
seen_awaits = set()
- deferred_patch = Deferred__next__Patch(seen_awaits, 1)
+ deferred_patch = Deferred__await__Patch(seen_awaits, 1)
try:
with deferred_patch.patch():
# do things
@@ -335,14 +341,14 @@ class Deferred__next__Patch:
"""
Args:
seen_awaits: The set of stack traces of `await`s that have been previously
- seen. When the `Deferred.__next__` patch sees a new `await`, it will add
+ seen. When the `Deferred.__await__` patch sees a new `await`, it will add
it to the set.
request_number: The request number to log against.
"""
self._request_number = request_number
self._seen_awaits = seen_awaits
- self._original_Deferred___next__ = Deferred.__next__ # type: ignore[misc,unused-ignore]
+ self._original_Deferred__await__ = Deferred.__await__ # type: ignore[misc,unused-ignore]
# The number of `await`s on `Deferred`s we have seen so far.
self.awaits_seen = 0
@@ -350,8 +356,13 @@ class Deferred__next__Patch:
# Whether we have seen a new `await` not in `seen_awaits`.
self.new_await_seen = False
+ # Whether to block new await points we see. This gets set to False once
+ # we have cancelled the request to allow things to run after
+ # cancellation.
+ self._block_new_awaits = True
+
# To force `await`s on resolved `Deferred`s to block, we make up a new
- # unresolved `Deferred` and return it out of `Deferred.__next__` /
+ # unresolved `Deferred` and return it out of `Deferred.__await__` /
# `coroutine.send()`. We have to resolve it later, in case the `await`ing
# coroutine is part of some shared processing, such as `@cached`.
self._to_unblock: Dict[Deferred, Union[object, Failure]] = {}
@@ -360,15 +371,15 @@ class Deferred__next__Patch:
self._previous_stack: List[inspect.FrameInfo] = []
def patch(self) -> ContextManager[Mock]:
- """Returns a context manager which patches `Deferred.__next__`."""
+ """Returns a context manager which patches `Deferred.__await__`."""
- def Deferred___next__(
- deferred: "Deferred[T]", value: object = None
- ) -> "Deferred[T]":
- """Intercepts `await`s on `Deferred`s and rigs them to block once we have
- seen enough of them.
+ def Deferred___await__(
+ deferred: "Deferred[T]",
+ ) -> Generator["Deferred[T]", None, T]:
+ """Intercepts calls to `__await__`, which returns a generator
+ yielding deferreds that we await on.
- `Deferred.__next__` will normally:
+ The generator for `__await__` will normally:
* return `self` if the `Deferred` is unresolved, in which case
`coroutine.send()` will return the `Deferred`, and
`_defer.inlineCallbacks` will stop running the coroutine until the
@@ -376,9 +387,43 @@ class Deferred__next__Patch:
* raise a `StopIteration(result)`, containing the result of the `await`.
* raise another exception, which will come out of the `await`.
"""
+
+ # Get the original generator.
+ gen = self._original_Deferred__await__(deferred)
+
+ # Run the generator, handling each iteration to see if we need to
+ # block.
+ try:
+ while True:
+ # We've hit a new await point (or the deferred has
+ # completed), handle it.
+ handle_next_iteration(deferred)
+
+ # Continue on.
+ yield gen.send(None)
+ except StopIteration as e:
+ # We need to convert `StopIteration` into a normal return.
+ return e.value
+
+ def handle_next_iteration(
+ deferred: "Deferred[T]",
+ ) -> None:
+ """Intercepts `await`s on `Deferred`s and rigs them to block once we have
+ seen enough of them.
+
+ Args:
+ deferred: The deferred that we've captured and are intercepting
+ `await` calls within.
+ """
+ if not self._block_new_awaits:
+ # We're no longer blocking awaits points
+ return
+
self.awaits_seen += 1
- stack = _get_stack(skip_frames=1)
+ stack = _get_stack(
+ skip_frames=2 # Ignore this function and `Deferred___await__` in stack trace
+ )
stack_hash = _hash_stack(stack)
if stack_hash not in self._seen_awaits:
@@ -389,20 +434,29 @@ class Deferred__next__Patch:
if not self.new_await_seen:
# This `await` isn't interesting. Let it proceed normally.
+ _log_await_stack(
+ stack,
+ self._previous_stack,
+ self._request_number,
+ "already seen",
+ )
+
# Don't log the stack. It's been seen before in a previous run.
self._previous_stack = stack
- return self._original_Deferred___next__(deferred, value)
+ return
# We want to block at the current `await`.
if deferred.called and not deferred.paused:
- # This `Deferred` already has a result.
- # We return a new, unresolved, `Deferred` for `_inlineCallbacks` to wait
- # on. This blocks the coroutine that did this `await`.
+ # This `Deferred` already has a result. We chain a new,
+ # unresolved, `Deferred` to the end of this Deferred that it
+ # will wait on. This blocks the coroutine that did this `await`.
# We queue it up for unblocking later.
new_deferred: "Deferred[T]" = Deferred()
self._to_unblock[new_deferred] = deferred.result
+ deferred.addBoth(lambda _: make_deferred_yieldable(new_deferred))
+
_log_await_stack(
stack,
self._previous_stack,
@@ -411,7 +465,9 @@ class Deferred__next__Patch:
)
self._previous_stack = stack
- return make_deferred_yieldable(new_deferred)
+ # Continue iterating on the deferred now that we've blocked it
+ # again.
+ return
# This `Deferred` does not have a result yet.
# The `await` will block normally, so we don't have to do anything.
@@ -423,9 +479,9 @@ class Deferred__next__Patch:
)
self._previous_stack = stack
- return self._original_Deferred___next__(deferred, value)
+ return
- return mock.patch.object(Deferred, "__next__", new=Deferred___next__)
+ return mock.patch.object(Deferred, "__await__", new=Deferred___await__)
def unblock_awaits(self) -> None:
"""Unblocks any shared processing that we forced to block.
@@ -433,6 +489,9 @@ class Deferred__next__Patch:
Must be called when done, otherwise processing shared between multiple requests,
such as database queries started by `@cached`, will become permanently stuck.
"""
+ # Also disable blocking at future await points
+ self._block_new_awaits = False
+
to_unblock = self._to_unblock
self._to_unblock = {}
for deferred, result in to_unblock.items():
|