diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 8618bb0651..e1eb8a4863 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -51,7 +51,7 @@ from typing import (
)
import attr
-from typing_extensions import Concatenate, Literal, ParamSpec
+from typing_extensions import Concatenate, Literal, ParamSpec, Unpack
from twisted.internet import defer
from twisted.internet.defer import CancelledError
@@ -61,6 +61,7 @@ from twisted.python.failure import Failure
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
+ run_coroutine_in_background,
run_in_background,
)
from synapse.util import Clock
@@ -344,6 +345,7 @@ T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
+T5 = TypeVar("T5")
@overload
@@ -402,6 +404,112 @@ def gather_results( # type: ignore[misc]
return deferred.addCallback(tuple)
+@overload
+async def gather_optional_coroutines(
+ *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]]]],
+) -> Tuple[Optional[T1]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+ *coroutines: Unpack[
+ Tuple[
+ Optional[Coroutine[Any, Any, T1]],
+ Optional[Coroutine[Any, Any, T2]],
+ ]
+ ],
+) -> Tuple[Optional[T1], Optional[T2]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+ *coroutines: Unpack[
+ Tuple[
+ Optional[Coroutine[Any, Any, T1]],
+ Optional[Coroutine[Any, Any, T2]],
+ Optional[Coroutine[Any, Any, T3]],
+ ]
+ ],
+) -> Tuple[Optional[T1], Optional[T2], Optional[T3]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+ *coroutines: Unpack[
+ Tuple[
+ Optional[Coroutine[Any, Any, T1]],
+ Optional[Coroutine[Any, Any, T2]],
+ Optional[Coroutine[Any, Any, T3]],
+ Optional[Coroutine[Any, Any, T4]],
+ ]
+ ],
+) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+ *coroutines: Unpack[
+ Tuple[
+ Optional[Coroutine[Any, Any, T1]],
+ Optional[Coroutine[Any, Any, T2]],
+ Optional[Coroutine[Any, Any, T3]],
+ Optional[Coroutine[Any, Any, T4]],
+ Optional[Coroutine[Any, Any, T5]],
+ ]
+ ],
+) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ...
+
+
+async def gather_optional_coroutines(
+ *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
+) -> Tuple[Optional[T1], ...]:
+ """Helper function that allows waiting on multiple coroutines at once.
+
+ The return value is a tuple of the return values of the coroutines in order.
+
+ If a `None` is passed instead of a coroutine, it will be ignored and a None
+ is returned in the tuple.
+
+ Note: For typechecking we need to have an explicit overload for each
+ distinct number of coroutines passed in. If you see type problems, it's
+ likely because you're using many arguments and you need to add a new
+ overload above.
+ """
+
+ try:
+ results = await make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_coroutine_in_background(coroutine)
+ for coroutine in coroutines
+ if coroutine is not None
+ ],
+ consumeErrors=True,
+ )
+ )
+
+ results_iter = iter(results)
+ return tuple(
+ next(results_iter) if coroutine is not None else None
+ for coroutine in coroutines
+ )
+ except defer.FirstError as dfe:
+ # unwrap the error from defer.gatherResults.
+
+ # The raised exception's traceback only includes func() etc if
+ # the 'await' happens before the exception is thrown - ie if the failure
+ # happens *asynchronously* - otherwise Twisted throws away the traceback as it
+ # could be large.
+ #
+ # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
+ # we could throw Twisted into the fires of Mordor.
+
+ # suppress exception chaining, because the FirstError doesn't tell us anything
+ # very interesting.
+ assert isinstance(dfe.subFailure.value, BaseException)
+ raise dfe.subFailure.value from None
+
+
@attr.s(slots=True, auto_attribs=True)
class _LinearizerEntry:
# The number of things executing.
|