diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index ab89cab812..cce8d595fc 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -11,9 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import traceback
+
from twisted.internet import defer
-from twisted.internet.defer import CancelledError, Deferred
+from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
from twisted.internet.task import Clock
+from twisted.python.failure import Failure
from synapse.logging.context import (
SENTINEL_CONTEXT,
@@ -21,7 +24,11 @@ from synapse.logging.context import (
PreserveLoggingContext,
current_context,
)
-from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
+from synapse.util.async_helpers import (
+ ObservableDeferred,
+ concurrently_execute,
+ timeout_deferred,
+)
from tests.unittest import TestCase
@@ -171,3 +178,107 @@ class TimeoutDeferredTest(TestCase):
)
self.failureResultOf(timing_out_d, defer.TimeoutError)
self.assertIs(current_context(), context_one)
+
+
+class _TestException(Exception):
+ pass
+
+
+class ConcurrentlyExecuteTest(TestCase):
+ def test_limits_runners(self):
+ """If we have more tasks than runners, we should get the limit of runners"""
+ started = 0
+ waiters = []
+ processed = []
+
+ async def callback(v):
+ # when we first enter, bump the start count
+ nonlocal started
+ started += 1
+
+ # record the fact we got an item
+ processed.append(v)
+
+ # wait for the goahead before returning
+ d2 = Deferred()
+ waiters.append(d2)
+ await d2
+
+ # set it going
+ d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3))
+
+ # check we got exactly 3 processes
+ self.assertEqual(started, 3)
+ self.assertEqual(len(waiters), 3)
+
+ # let one finish
+ waiters.pop().callback(0)
+
+ # ... which should start another
+ self.assertEqual(started, 4)
+ self.assertEqual(len(waiters), 3)
+
+ # we still shouldn't be done
+ self.assertNoResult(d2)
+
+ # finish the job
+ while waiters:
+ waiters.pop().callback(0)
+
+ # check everything got done
+ self.assertEqual(started, 5)
+ self.assertCountEqual(processed, [1, 2, 3, 4, 5])
+ self.successResultOf(d2)
+
+ def test_preserves_stacktraces(self):
+ """Test that the stacktrace from an exception thrown in the callback is preserved"""
+ d1 = Deferred()
+
+ async def callback(v):
+ # alas, this doesn't work at all without an await here
+ await d1
+ raise _TestException("bah")
+
+ async def caller():
+ try:
+ await concurrently_execute(callback, [1], 2)
+ except _TestException as e:
+ tb = traceback.extract_tb(e.__traceback__)
+ # we expect to see "caller", "concurrently_execute" and "callback".
+ self.assertEqual(tb[0].name, "caller")
+ self.assertEqual(tb[1].name, "concurrently_execute")
+ self.assertEqual(tb[-1].name, "callback")
+ else:
+ self.fail("No exception thrown")
+
+ d2 = ensureDeferred(caller())
+ d1.callback(0)
+ self.successResultOf(d2)
+
+ def test_preserves_stacktraces_on_preformed_failure(self):
+ """Test that the stacktrace on a Failure returned by the callback is preserved"""
+ d1 = Deferred()
+ f = Failure(_TestException("bah"))
+
+ async def callback(v):
+ # alas, this doesn't work at all without an await here
+ await d1
+ await defer.fail(f)
+
+ async def caller():
+ try:
+ await concurrently_execute(callback, [1], 2)
+ except _TestException as e:
+ tb = traceback.extract_tb(e.__traceback__)
+ # we expect to see "caller", "concurrently_execute", "callback",
+ # and some magic from inside ensureDeferred that happens when .fail
+ # is called.
+ self.assertEqual(tb[0].name, "caller")
+ self.assertEqual(tb[1].name, "concurrently_execute")
+ self.assertEqual(tb[-2].name, "callback")
+ else:
+ self.fail("No exception thrown")
+
+ d2 = ensureDeferred(caller())
+ d1.callback(0)
+ self.successResultOf(d2)
|