diff options
Diffstat (limited to 'tests/util')
-rw-r--r-- | tests/util/test_async_helpers.py | 115 |
1 files changed, 113 insertions, 2 deletions
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index ab89cab812..cce8d595fc 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import traceback + from twisted.internet import defer -from twisted.internet.defer import CancelledError, Deferred +from twisted.internet.defer import CancelledError, Deferred, ensureDeferred from twisted.internet.task import Clock +from twisted.python.failure import Failure from synapse.logging.context import ( SENTINEL_CONTEXT, @@ -21,7 +24,11 @@ from synapse.logging.context import ( PreserveLoggingContext, current_context, ) -from synapse.util.async_helpers import ObservableDeferred, timeout_deferred +from synapse.util.async_helpers import ( + ObservableDeferred, + concurrently_execute, + timeout_deferred, +) from tests.unittest import TestCase @@ -171,3 +178,107 @@ class TimeoutDeferredTest(TestCase): ) self.failureResultOf(timing_out_d, defer.TimeoutError) self.assertIs(current_context(), context_one) + + +class _TestException(Exception): + pass + + +class ConcurrentlyExecuteTest(TestCase): + def test_limits_runners(self): + """If we have more tasks than runners, we should get the limit of runners""" + started = 0 + waiters = [] + processed = [] + + async def callback(v): + # when we first enter, bump the start count + nonlocal started + started += 1 + + # record the fact we got an item + processed.append(v) + + # wait for the goahead before returning + d2 = Deferred() + waiters.append(d2) + await d2 + + # set it going + d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3)) + + # check we got exactly 3 processes + self.assertEqual(started, 3) + self.assertEqual(len(waiters), 3) + + # let one finish + waiters.pop().callback(0) + + # ... which should start another + self.assertEqual(started, 4) + self.assertEqual(len(waiters), 3) + + # we still shouldn't be done + self.assertNoResult(d2) + + # finish the job + while waiters: + waiters.pop().callback(0) + + # check everything got done + self.assertEqual(started, 5) + self.assertCountEqual(processed, [1, 2, 3, 4, 5]) + self.successResultOf(d2) + + def test_preserves_stacktraces(self): + """Test that the stacktrace from an exception thrown in the callback is preserved""" + d1 = Deferred() + + async def callback(v): + # alas, this doesn't work at all without an await here + await d1 + raise _TestException("bah") + + async def caller(): + try: + await concurrently_execute(callback, [1], 2) + except _TestException as e: + tb = traceback.extract_tb(e.__traceback__) + # we expect to see "caller", "concurrently_execute" and "callback". + self.assertEqual(tb[0].name, "caller") + self.assertEqual(tb[1].name, "concurrently_execute") + self.assertEqual(tb[-1].name, "callback") + else: + self.fail("No exception thrown") + + d2 = ensureDeferred(caller()) + d1.callback(0) + self.successResultOf(d2) + + def test_preserves_stacktraces_on_preformed_failure(self): + """Test that the stacktrace on a Failure returned by the callback is preserved""" + d1 = Deferred() + f = Failure(_TestException("bah")) + + async def callback(v): + # alas, this doesn't work at all without an await here + await d1 + await defer.fail(f) + + async def caller(): + try: + await concurrently_execute(callback, [1], 2) + except _TestException as e: + tb = traceback.extract_tb(e.__traceback__) + # we expect to see "caller", "concurrently_execute", "callback", + # and some magic from inside ensureDeferred that happens when .fail + # is called. + self.assertEqual(tb[0].name, "caller") + self.assertEqual(tb[1].name, "concurrently_execute") + self.assertEqual(tb[-2].name, "callback") + else: + self.fail("No exception thrown") + + d2 = ensureDeferred(caller()) + d1.callback(0) + self.successResultOf(d2) |