summary refs log tree commit diff
path: root/tests/util/test_async_helpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util/test_async_helpers.py')
-rw-r--r--tests/util/test_async_helpers.py115
1 files changed, 113 insertions, 2 deletions
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index ab89cab812..cce8d595fc 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -11,9 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import traceback
+
 from twisted.internet import defer
-from twisted.internet.defer import CancelledError, Deferred
+from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
 from twisted.internet.task import Clock
+from twisted.python.failure import Failure
 
 from synapse.logging.context import (
     SENTINEL_CONTEXT,
@@ -21,7 +24,11 @@ from synapse.logging.context import (
     PreserveLoggingContext,
     current_context,
 )
-from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
+from synapse.util.async_helpers import (
+    ObservableDeferred,
+    concurrently_execute,
+    timeout_deferred,
+)
 
 from tests.unittest import TestCase
 
@@ -171,3 +178,107 @@ class TimeoutDeferredTest(TestCase):
             )
             self.failureResultOf(timing_out_d, defer.TimeoutError)
             self.assertIs(current_context(), context_one)
+
+
+class _TestException(Exception):
+    pass
+
+
+class ConcurrentlyExecuteTest(TestCase):
+    def test_limits_runners(self):
+        """If we have more tasks than runners, we should get the limit of runners"""
+        started = 0
+        waiters = []
+        processed = []
+
+        async def callback(v):
+            # when we first enter, bump the start count
+            nonlocal started
+            started += 1
+
+            # record the fact we got an item
+            processed.append(v)
+
+            # wait for the goahead before returning
+            d2 = Deferred()
+            waiters.append(d2)
+            await d2
+
+        # set it going
+        d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3))
+
+        # check we got exactly 3 processes
+        self.assertEqual(started, 3)
+        self.assertEqual(len(waiters), 3)
+
+        # let one finish
+        waiters.pop().callback(0)
+
+        # ... which should start another
+        self.assertEqual(started, 4)
+        self.assertEqual(len(waiters), 3)
+
+        # we still shouldn't be done
+        self.assertNoResult(d2)
+
+        # finish the job
+        while waiters:
+            waiters.pop().callback(0)
+
+        # check everything got done
+        self.assertEqual(started, 5)
+        self.assertCountEqual(processed, [1, 2, 3, 4, 5])
+        self.successResultOf(d2)
+
+    def test_preserves_stacktraces(self):
+        """Test that the stacktrace from an exception thrown in the callback is preserved"""
+        d1 = Deferred()
+
+        async def callback(v):
+            # alas, this doesn't work at all without an await here
+            await d1
+            raise _TestException("bah")
+
+        async def caller():
+            try:
+                await concurrently_execute(callback, [1], 2)
+            except _TestException as e:
+                tb = traceback.extract_tb(e.__traceback__)
+                # we expect to see "caller", "concurrently_execute" and "callback".
+                self.assertEqual(tb[0].name, "caller")
+                self.assertEqual(tb[1].name, "concurrently_execute")
+                self.assertEqual(tb[-1].name, "callback")
+            else:
+                self.fail("No exception thrown")
+
+        d2 = ensureDeferred(caller())
+        d1.callback(0)
+        self.successResultOf(d2)
+
+    def test_preserves_stacktraces_on_preformed_failure(self):
+        """Test that the stacktrace on a Failure returned by the callback is preserved"""
+        d1 = Deferred()
+        f = Failure(_TestException("bah"))
+
+        async def callback(v):
+            # alas, this doesn't work at all without an await here
+            await d1
+            await defer.fail(f)
+
+        async def caller():
+            try:
+                await concurrently_execute(callback, [1], 2)
+            except _TestException as e:
+                tb = traceback.extract_tb(e.__traceback__)
+                # we expect to see "caller", "concurrently_execute", "callback",
+                # and some magic from inside ensureDeferred that happens when .fail
+                # is called.
+                self.assertEqual(tb[0].name, "caller")
+                self.assertEqual(tb[1].name, "concurrently_execute")
+                self.assertEqual(tb[-2].name, "callback")
+            else:
+                self.fail("No exception thrown")
+
+        d2 = ensureDeferred(caller())
+        d1.callback(0)
+        self.successResultOf(d2)