summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
authorreivilibre <oliverw@matrix.org>2023-01-10 11:17:59 +0000
committerGitHub <noreply@github.com>2023-01-10 11:17:59 +0000
commitba4ea7d13ffae53644b206222af95a5171faa27c (patch)
tree7867aabc7a90d7ad1b539c015db7115d50af1d8c /synapse/util
parentAdd missing worker settings to shared configuration (#14748) (diff)
downloadsynapse-ba4ea7d13ffae53644b206222af95a5171faa27c.tar.xz
Batch up replication requests to request the resyncing of remote users's devices. (#14716)
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/async_helpers.py55
1 files changed, 51 insertions, 4 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index d24c4f68c4..01e3cd46f6 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -205,7 +205,10 @@ T = TypeVar("T")
 
 
 async def concurrently_execute(
-    func: Callable[[T], Any], args: Iterable[T], limit: int
+    func: Callable[[T], Any],
+    args: Iterable[T],
+    limit: int,
+    delay_cancellation: bool = False,
 ) -> None:
     """Executes the function with each argument concurrently while limiting
     the number of concurrent executions.
@@ -215,6 +218,8 @@ async def concurrently_execute(
         args: List of arguments to pass to func, each invocation of func
             gets a single argument.
         limit: Maximum number of conccurent executions.
+        delay_cancellation: Whether to delay cancellation until after the invocations
+            have finished.
 
     Returns:
         None, when all function invocations have finished. The return values
@@ -233,9 +238,16 @@ async def concurrently_execute(
     # We use `itertools.islice` to handle the case where the number of args is
     # less than the limit, avoiding needlessly spawning unnecessary background
     # tasks.
-    await yieldable_gather_results(
-        _concurrently_execute_inner, (value for value in itertools.islice(it, limit))
-    )
+    if delay_cancellation:
+        await yieldable_gather_results_delaying_cancellation(
+            _concurrently_execute_inner,
+            (value for value in itertools.islice(it, limit)),
+        )
+    else:
+        await yieldable_gather_results(
+            _concurrently_execute_inner,
+            (value for value in itertools.islice(it, limit)),
+        )
 
 
 P = ParamSpec("P")
@@ -292,6 +304,41 @@ async def yieldable_gather_results(
         raise dfe.subFailure.value from None
 
 
+async def yieldable_gather_results_delaying_cancellation(
+    func: Callable[Concatenate[T, P], Awaitable[R]],
+    iter: Iterable[T],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> List[R]:
+    """Executes the function with each argument concurrently.
+    Cancellation is delayed until after all the results have been gathered.
+
+    See `yieldable_gather_results`.
+
+    Args:
+        func: Function to execute that returns a Deferred
+        iter: An iterable that yields items that get passed as the first
+            argument to the function
+        *args: Arguments to be passed to each call to func
+        **kwargs: Keyword arguments to be passed to each call to func
+
+    Returns
+        A list containing the results of the function
+    """
+    try:
+        return await make_deferred_yieldable(
+            delay_cancellation(
+                defer.gatherResults(
+                    [run_in_background(func, item, *args, **kwargs) for item in iter],  # type: ignore[arg-type]
+                    consumeErrors=True,
+                )
+            )
+        )
+    except defer.FirstError as dfe:
+        assert isinstance(dfe.subFailure.value, BaseException)
+        raise dfe.subFailure.value from None
+
+
 T1 = TypeVar("T1")
 T2 = TypeVar("T2")
 T3 = TypeVar("T3")