diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index d82822d00d..350a2b7c8c 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -18,7 +18,7 @@
#
#
import traceback
-from typing import Generator, List, NoReturn, Optional
+from typing import Any, Coroutine, Generator, List, NoReturn, Optional, Tuple, TypeVar
from parameterized import parameterized_class
@@ -39,6 +39,7 @@ from synapse.util.async_helpers import (
ObservableDeferred,
concurrently_execute,
delay_cancellation,
+ gather_optional_coroutines,
stop_cancellation,
timeout_deferred,
)
@@ -46,6 +47,8 @@ from synapse.util.async_helpers import (
from tests.server import get_clock
from tests.unittest import TestCase
+T = TypeVar("T")
+
class ObservableDeferredTest(TestCase):
def test_succeed(self) -> None:
@@ -588,3 +591,106 @@ class AwakenableSleeperTests(TestCase):
sleeper.wake("name")
self.assertTrue(d1.called)
self.assertTrue(d2.called)
+
+
+class GatherCoroutineTests(TestCase):
+ """Tests for `gather_optional_coroutines`"""
+
+ def make_coroutine(self) -> Tuple[Coroutine[Any, Any, T], "defer.Deferred[T]"]:
+ """Returns a coroutine and a deferred that it is waiting on to resolve"""
+
+ d: "defer.Deferred[T]" = defer.Deferred()
+
+ async def inner() -> T:
+ with PreserveLoggingContext():
+ return await d
+
+ return inner(), d
+
+ def test_single(self) -> None:
+ "Test passing in a single coroutine works"
+
+ with LoggingContext("test_ctx") as text_ctx:
+ deferred: "defer.Deferred[None]"
+ coroutine, deferred = self.make_coroutine()
+
+ gather_deferred = defer.ensureDeferred(
+ gather_optional_coroutines(coroutine)
+ )
+
+ # We shouldn't have a result yet, and should be in the sentinel
+ # context.
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ # Resolving the deferred will resolve the coroutine
+ deferred.callback(None)
+
+ # All coroutines have resolved, and so we should have the results
+ result = self.successResultOf(gather_deferred)
+ self.assertEqual(result, (None,))
+
+ # We should be back in the normal context.
+ self.assertEqual(current_context(), text_ctx)
+
+ def test_multiple_resolve(self) -> None:
+ "Test passing in multiple coroutine that all resolve works"
+
+ with LoggingContext("test_ctx") as test_ctx:
+ deferred1: "defer.Deferred[int]"
+ coroutine1, deferred1 = self.make_coroutine()
+ deferred2: "defer.Deferred[str]"
+ coroutine2, deferred2 = self.make_coroutine()
+
+ gather_deferred = defer.ensureDeferred(
+ gather_optional_coroutines(coroutine1, coroutine2)
+ )
+
+ # We shouldn't have a result yet, and should be in the sentinel
+ # context.
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ # Even if we resolve one of the coroutines, we shouldn't have a result
+ # yet
+ deferred2.callback("test")
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ deferred1.callback(1)
+
+ # All coroutines have resolved, and so we should have the results
+ result = self.successResultOf(gather_deferred)
+ self.assertEqual(result, (1, "test"))
+
+ # We should be back in the normal context.
+ self.assertEqual(current_context(), test_ctx)
+
+ def test_multiple_fail(self) -> None:
+ "Test passing in multiple coroutine where one fails does the right thing"
+
+ with LoggingContext("test_ctx") as test_ctx:
+ deferred1: "defer.Deferred[int]"
+ coroutine1, deferred1 = self.make_coroutine()
+ deferred2: "defer.Deferred[str]"
+ coroutine2, deferred2 = self.make_coroutine()
+
+ gather_deferred = defer.ensureDeferred(
+ gather_optional_coroutines(coroutine1, coroutine2)
+ )
+
+ # We shouldn't have a result yet, and should be in the sentinel
+ # context.
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ # Throw an exception in one of the coroutines
+ exc = Exception("test")
+ deferred2.errback(exc)
+
+ # Expect the gather deferred to immediately fail
+ result_exc = self.failureResultOf(gather_deferred)
+ self.assertEqual(result_exc.value, exc)
+
+ # We should be back in the normal context.
+ self.assertEqual(current_context(), test_ctx)
|