summary refs log tree commit diff
path: root/synapse/util/async_helpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/async_helpers.py')
-rw-r--r--synapse/util/async_helpers.py21
1 files changed, 16 insertions, 5 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 5c55bb0125..061102c3c8 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -15,6 +15,7 @@
 
 import collections
 import inspect
+import itertools
 import logging
 from contextlib import contextmanager
 from typing import (
@@ -160,8 +161,11 @@ class ObservableDeferred:
         )
 
 
+T = TypeVar("T")
+
+
 def concurrently_execute(
-    func: Callable, args: Iterable[Any], limit: int
+    func: Callable[[T], Any], args: Iterable[T], limit: int
 ) -> defer.Deferred:
     """Executes the function with each argument concurrently while limiting
     the number of concurrent executions.
@@ -173,20 +177,27 @@ def concurrently_execute(
         limit: Maximum number of conccurent executions.
 
     Returns:
-        Deferred[list]: Resolved when all function invocations have finished.
+        Deferred: Resolved when all function invocations have finished.
     """
     it = iter(args)
 
-    async def _concurrently_execute_inner():
+    async def _concurrently_execute_inner(value: T) -> None:
         try:
             while True:
-                await maybe_awaitable(func(next(it)))
+                await maybe_awaitable(func(value))
+                value = next(it)
         except StopIteration:
             pass
 
+    # We use `itertools.islice` to handle the case where the number of args is
+    # less than the limit, avoiding needlessly spawning unnecessary background
+    # tasks.
     return make_deferred_yieldable(
         defer.gatherResults(
-            [run_in_background(_concurrently_execute_inner) for _ in range(limit)],
+            [
+                run_in_background(_concurrently_execute_inner, value)
+                for value in itertools.islice(it, limit)
+            ],
             consumeErrors=True,
         )
     ).addErrback(unwrapFirstError)