summary refs log tree commit diff
path: root/synapse/util/async_helpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/async_helpers.py')
-rw-r--r--synapse/util/async_helpers.py54
1 files changed, 32 insertions, 22 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 3f7299aff7..a83296a229 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -29,6 +29,7 @@ from typing import (
     Hashable,
     Iterable,
     Iterator,
+    List,
     Optional,
     Set,
     Tuple,
@@ -51,7 +52,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
     run_in_background,
 )
-from synapse.util import Clock, unwrapFirstError
+from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
 
@@ -193,9 +194,9 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
 T = TypeVar("T")
 
 
-def concurrently_execute(
+async def concurrently_execute(
     func: Callable[[T], Any], args: Iterable[T], limit: int
-) -> defer.Deferred:
+) -> None:
     """Executes the function with each argument concurrently while limiting
     the number of concurrent executions.
 
@@ -221,20 +222,14 @@ def concurrently_execute(
     # We use `itertools.islice` to handle the case where the number of args is
     # less than the limit, avoiding needlessly spawning unnecessary background
     # tasks.
-    return make_deferred_yieldable(
-        defer.gatherResults(
-            [
-                run_in_background(_concurrently_execute_inner, value)
-                for value in itertools.islice(it, limit)
-            ],
-            consumeErrors=True,
-        )
-    ).addErrback(unwrapFirstError)
+    await yieldable_gather_results(
+        _concurrently_execute_inner, (value for value in itertools.islice(it, limit))
+    )
 
 
-def yieldable_gather_results(
-    func: Callable, iter: Iterable, *args: Any, **kwargs: Any
-) -> defer.Deferred:
+async def yieldable_gather_results(
+    func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any
+) -> List[T]:
     """Executes the function with each argument concurrently.
 
     Args:
@@ -245,15 +240,30 @@ def yieldable_gather_results(
         **kwargs: Keyword arguments to be passed to each call to func
 
     Returns
-        Deferred[list]: Resolved when all functions have been invoked, or errors if
-        one of the function calls fails.
+        A list containing the results of the function
     """
-    return make_deferred_yieldable(
-        defer.gatherResults(
-            [run_in_background(func, item, *args, **kwargs) for item in iter],
-            consumeErrors=True,
+    try:
+        return await make_deferred_yieldable(
+            defer.gatherResults(
+                [run_in_background(func, item, *args, **kwargs) for item in iter],
+                consumeErrors=True,
+            )
         )
-    ).addErrback(unwrapFirstError)
+    except defer.FirstError as dfe:
+        # unwrap the error from defer.gatherResults.
+
+        # The raised exception's traceback only includes func() etc if
+        # the 'await' happens before the exception is thrown - ie if the failure
+        # happens *asynchronously* - otherwise Twisted throws away the traceback as it
+        # could be large.
+        #
+        # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
+        # we could throw Twisted into the fires of Mordor.
+
+        # suppress exception chaining, because the FirstError doesn't tell us anything
+        # very interesting.
+        assert isinstance(dfe.subFailure.value, BaseException)
+        raise dfe.subFailure.value from None
 
 
 T1 = TypeVar("T1")