diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index d24c4f68c4..01e3cd46f6 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -205,7 +205,10 @@ T = TypeVar("T")
async def concurrently_execute(
- func: Callable[[T], Any], args: Iterable[T], limit: int
+ func: Callable[[T], Any],
+ args: Iterable[T],
+ limit: int,
+ delay_cancellation: bool = False,
) -> None:
"""Executes the function with each argument concurrently while limiting
the number of concurrent executions.
@@ -215,6 +218,8 @@ async def concurrently_execute(
args: List of arguments to pass to func, each invocation of func
gets a single argument.
limit: Maximum number of conccurent executions.
+ delay_cancellation: Whether to delay cancellation until after the invocations
+ have finished.
Returns:
None, when all function invocations have finished. The return values
@@ -233,9 +238,16 @@ async def concurrently_execute(
# We use `itertools.islice` to handle the case where the number of args is
# less than the limit, avoiding needlessly spawning unnecessary background
# tasks.
- await yieldable_gather_results(
- _concurrently_execute_inner, (value for value in itertools.islice(it, limit))
- )
+ if delay_cancellation:
+ await yieldable_gather_results_delaying_cancellation(
+ _concurrently_execute_inner,
+ (value for value in itertools.islice(it, limit)),
+ )
+ else:
+ await yieldable_gather_results(
+ _concurrently_execute_inner,
+ (value for value in itertools.islice(it, limit)),
+ )
P = ParamSpec("P")
@@ -292,6 +304,41 @@ async def yieldable_gather_results(
raise dfe.subFailure.value from None
+async def yieldable_gather_results_delaying_cancellation(
+ func: Callable[Concatenate[T, P], Awaitable[R]],
+ iter: Iterable[T],
+ *args: P.args,
+ **kwargs: P.kwargs,
+) -> List[R]:
+ """Executes the function with each argument concurrently.
+ Cancellation is delayed until after all the results have been gathered.
+
+ See `yieldable_gather_results`.
+
+ Args:
+ func: Function to execute that returns a Deferred
+ iter: An iterable that yields items that get passed as the first
+ argument to the function
+ *args: Arguments to be passed to each call to func
+ **kwargs: Keyword arguments to be passed to each call to func
+
+ Returns
+ A list containing the results of the function
+ """
+ try:
+ return await make_deferred_yieldable(
+ delay_cancellation(
+ defer.gatherResults(
+ [run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type]
+ consumeErrors=True,
+ )
+ )
+ )
+ except defer.FirstError as dfe:
+ assert isinstance(dfe.subFailure.value, BaseException)
+ raise dfe.subFailure.value from None
+
+
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
|