summary refs log tree commit diff
path: root/tests/util/test_async_helpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util/test_async_helpers.py')
-rw-r--r--tests/util/test_async_helpers.py108
1 files changed, 107 insertions, 1 deletions
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py

index d82822d00d..350a2b7c8c 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py
@@ -18,7 +18,7 @@ # # import traceback -from typing import Generator, List, NoReturn, Optional +from typing import Any, Coroutine, Generator, List, NoReturn, Optional, Tuple, TypeVar from parameterized import parameterized_class @@ -39,6 +39,7 @@ from synapse.util.async_helpers import ( ObservableDeferred, concurrently_execute, delay_cancellation, + gather_optional_coroutines, stop_cancellation, timeout_deferred, ) @@ -46,6 +47,8 @@ from synapse.util.async_helpers import ( from tests.server import get_clock from tests.unittest import TestCase +T = TypeVar("T") + class ObservableDeferredTest(TestCase): def test_succeed(self) -> None: @@ -588,3 +591,106 @@ class AwakenableSleeperTests(TestCase): sleeper.wake("name") self.assertTrue(d1.called) self.assertTrue(d2.called) + + +class GatherCoroutineTests(TestCase): + """Tests for `gather_optional_coroutines`""" + + def make_coroutine(self) -> Tuple[Coroutine[Any, Any, T], "defer.Deferred[T]"]: + """Returns a coroutine and a deferred that it is waiting on to resolve""" + + d: "defer.Deferred[T]" = defer.Deferred() + + async def inner() -> T: + with PreserveLoggingContext(): + return await d + + return inner(), d + + def test_single(self) -> None: + "Test passing in a single coroutine works" + + with LoggingContext("test_ctx") as text_ctx: + deferred: "defer.Deferred[None]" + coroutine, deferred = self.make_coroutine() + + gather_deferred = defer.ensureDeferred( + gather_optional_coroutines(coroutine) + ) + + # We shouldn't have a result yet, and should be in the sentinel + # context. + self.assertNoResult(gather_deferred) + self.assertEqual(current_context(), SENTINEL_CONTEXT) + + # Resolving the deferred will resolve the coroutine + deferred.callback(None) + + # All coroutines have resolved, and so we should have the results + result = self.successResultOf(gather_deferred) + self.assertEqual(result, (None,)) + + # We should be back in the normal context. + self.assertEqual(current_context(), text_ctx) + + def test_multiple_resolve(self) -> None: + "Test passing in multiple coroutine that all resolve works" + + with LoggingContext("test_ctx") as test_ctx: + deferred1: "defer.Deferred[int]" + coroutine1, deferred1 = self.make_coroutine() + deferred2: "defer.Deferred[str]" + coroutine2, deferred2 = self.make_coroutine() + + gather_deferred = defer.ensureDeferred( + gather_optional_coroutines(coroutine1, coroutine2) + ) + + # We shouldn't have a result yet, and should be in the sentinel + # context. + self.assertNoResult(gather_deferred) + self.assertEqual(current_context(), SENTINEL_CONTEXT) + + # Even if we resolve one of the coroutines, we shouldn't have a result + # yet + deferred2.callback("test") + self.assertNoResult(gather_deferred) + self.assertEqual(current_context(), SENTINEL_CONTEXT) + + deferred1.callback(1) + + # All coroutines have resolved, and so we should have the results + result = self.successResultOf(gather_deferred) + self.assertEqual(result, (1, "test")) + + # We should be back in the normal context. + self.assertEqual(current_context(), test_ctx) + + def test_multiple_fail(self) -> None: + "Test passing in multiple coroutine where one fails does the right thing" + + with LoggingContext("test_ctx") as test_ctx: + deferred1: "defer.Deferred[int]" + coroutine1, deferred1 = self.make_coroutine() + deferred2: "defer.Deferred[str]" + coroutine2, deferred2 = self.make_coroutine() + + gather_deferred = defer.ensureDeferred( + gather_optional_coroutines(coroutine1, coroutine2) + ) + + # We shouldn't have a result yet, and should be in the sentinel + # context. + self.assertNoResult(gather_deferred) + self.assertEqual(current_context(), SENTINEL_CONTEXT) + + # Throw an exception in one of the coroutines + exc = Exception("test") + deferred2.errback(exc) + + # Expect the gather deferred to immediately fail + result_exc = self.failureResultOf(gather_deferred) + self.assertEqual(result_exc.value, exc) + + # We should be back in the normal context. + self.assertEqual(current_context(), test_ctx)