summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12109.misc1
-rw-r--r--synapse/handlers/message.py4
-rw-r--r--synapse/util/__init__.py4
-rw-r--r--synapse/util/async_helpers.py54
-rw-r--r--tests/util/test_async_helpers.py115
5 files changed, 151 insertions, 27 deletions
diff --git a/changelog.d/12109.misc b/changelog.d/12109.misc
new file mode 100644
index 0000000000..3295e49f43
--- /dev/null
+++ b/changelog.d/12109.misc
@@ -0,0 +1 @@
+Improve exception handling for concurrent execution.
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index a9c964cd75..ce1fa3c78e 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -55,8 +55,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
-from synapse.util import json_decoder, json_encoder, log_failure
-from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
+from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
+from synapse.util.async_helpers import Linearizer, gather_results
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.metrics import measure_func
 from synapse.visibility import filter_events_for_client
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 511f52534b..58b4220ff3 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -81,7 +81,9 @@ json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
 
 
 def unwrapFirstError(failure: Failure) -> Failure:
-    # defer.gatherResults and DeferredLists wrap failures.
+    # Deprecated: you probably just want to catch defer.FirstError and reraise
+    # the subFailure's value, which will do a better job of preserving stacktraces.
+    # (actually, you probably want to use yieldable_gather_results anyway)
     failure.trap(defer.FirstError)
     return failure.value.subFailure  # type: ignore[union-attr]  # Issue in Twisted's annotations
 
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 3f7299aff7..a83296a229 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -29,6 +29,7 @@ from typing import (
     Hashable,
     Iterable,
     Iterator,
+    List,
     Optional,
     Set,
     Tuple,
@@ -51,7 +52,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
     run_in_background,
 )
-from synapse.util import Clock, unwrapFirstError
+from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
 
@@ -193,9 +194,9 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
 T = TypeVar("T")
 
 
-def concurrently_execute(
+async def concurrently_execute(
     func: Callable[[T], Any], args: Iterable[T], limit: int
-) -> defer.Deferred:
+) -> None:
     """Executes the function with each argument concurrently while limiting
     the number of concurrent executions.
 
@@ -221,20 +222,14 @@ def concurrently_execute(
     # We use `itertools.islice` to handle the case where the number of args is
     # less than the limit, avoiding needlessly spawning unnecessary background
     # tasks.
-    return make_deferred_yieldable(
-        defer.gatherResults(
-            [
-                run_in_background(_concurrently_execute_inner, value)
-                for value in itertools.islice(it, limit)
-            ],
-            consumeErrors=True,
-        )
-    ).addErrback(unwrapFirstError)
+    await yieldable_gather_results(
+        _concurrently_execute_inner, (value for value in itertools.islice(it, limit))
+    )
 
 
-def yieldable_gather_results(
-    func: Callable, iter: Iterable, *args: Any, **kwargs: Any
-) -> defer.Deferred:
+async def yieldable_gather_results(
+    func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any
+) -> List[T]:
     """Executes the function with each argument concurrently.
 
     Args:
@@ -245,15 +240,30 @@ def yieldable_gather_results(
         **kwargs: Keyword arguments to be passed to each call to func
 
     Returns
-        Deferred[list]: Resolved when all functions have been invoked, or errors if
-        one of the function calls fails.
+        A list containing the results of the function
     """
-    return make_deferred_yieldable(
-        defer.gatherResults(
-            [run_in_background(func, item, *args, **kwargs) for item in iter],
-            consumeErrors=True,
+    try:
+        return await make_deferred_yieldable(
+            defer.gatherResults(
+                [run_in_background(func, item, *args, **kwargs) for item in iter],
+                consumeErrors=True,
+            )
         )
-    ).addErrback(unwrapFirstError)
+    except defer.FirstError as dfe:
+        # unwrap the error from defer.gatherResults.
+
+        # The raised exception's traceback only includes func() etc if
+        # the 'await' happens before the exception is thrown - ie if the failure
+        # happens *asynchronously* - otherwise Twisted throws away the traceback as it
+        # could be large.
+        #
+        # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
+        # we could throw Twisted into the fires of Mordor.
+
+        # suppress exception chaining, because the FirstError doesn't tell us anything
+        # very interesting.
+        assert isinstance(dfe.subFailure.value, BaseException)
+        raise dfe.subFailure.value from None
 
 
 T1 = TypeVar("T1")
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index ab89cab812..cce8d595fc 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -11,9 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import traceback
+
 from twisted.internet import defer
-from twisted.internet.defer import CancelledError, Deferred
+from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
 from twisted.internet.task import Clock
+from twisted.python.failure import Failure
 
 from synapse.logging.context import (
     SENTINEL_CONTEXT,
@@ -21,7 +24,11 @@ from synapse.logging.context import (
     PreserveLoggingContext,
     current_context,
 )
-from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
+from synapse.util.async_helpers import (
+    ObservableDeferred,
+    concurrently_execute,
+    timeout_deferred,
+)
 
 from tests.unittest import TestCase
 
@@ -171,3 +178,107 @@ class TimeoutDeferredTest(TestCase):
             )
             self.failureResultOf(timing_out_d, defer.TimeoutError)
             self.assertIs(current_context(), context_one)
+
+
+class _TestException(Exception):
+    pass
+
+
+class ConcurrentlyExecuteTest(TestCase):
+    def test_limits_runners(self):
+        """If we have more tasks than runners, we should get the limit of runners"""
+        started = 0
+        waiters = []
+        processed = []
+
+        async def callback(v):
+            # when we first enter, bump the start count
+            nonlocal started
+            started += 1
+
+            # record the fact we got an item
+            processed.append(v)
+
+            # wait for the goahead before returning
+            d2 = Deferred()
+            waiters.append(d2)
+            await d2
+
+        # set it going
+        d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3))
+
+        # check we got exactly 3 processes
+        self.assertEqual(started, 3)
+        self.assertEqual(len(waiters), 3)
+
+        # let one finish
+        waiters.pop().callback(0)
+
+        # ... which should start another
+        self.assertEqual(started, 4)
+        self.assertEqual(len(waiters), 3)
+
+        # we still shouldn't be done
+        self.assertNoResult(d2)
+
+        # finish the job
+        while waiters:
+            waiters.pop().callback(0)
+
+        # check everything got done
+        self.assertEqual(started, 5)
+        self.assertCountEqual(processed, [1, 2, 3, 4, 5])
+        self.successResultOf(d2)
+
+    def test_preserves_stacktraces(self):
+        """Test that the stacktrace from an exception thrown in the callback is preserved"""
+        d1 = Deferred()
+
+        async def callback(v):
+            # alas, this doesn't work at all without an await here
+            await d1
+            raise _TestException("bah")
+
+        async def caller():
+            try:
+                await concurrently_execute(callback, [1], 2)
+            except _TestException as e:
+                tb = traceback.extract_tb(e.__traceback__)
+                # we expect to see "caller", "concurrently_execute" and "callback".
+                self.assertEqual(tb[0].name, "caller")
+                self.assertEqual(tb[1].name, "concurrently_execute")
+                self.assertEqual(tb[-1].name, "callback")
+            else:
+                self.fail("No exception thrown")
+
+        d2 = ensureDeferred(caller())
+        d1.callback(0)
+        self.successResultOf(d2)
+
+    def test_preserves_stacktraces_on_preformed_failure(self):
+        """Test that the stacktrace on a Failure returned by the callback is preserved"""
+        d1 = Deferred()
+        f = Failure(_TestException("bah"))
+
+        async def callback(v):
+            # alas, this doesn't work at all without an await here
+            await d1
+            await defer.fail(f)
+
+        async def caller():
+            try:
+                await concurrently_execute(callback, [1], 2)
+            except _TestException as e:
+                tb = traceback.extract_tb(e.__traceback__)
+                # we expect to see "caller", "concurrently_execute", "callback",
+                # and some magic from inside ensureDeferred that happens when .fail
+                # is called.
+                self.assertEqual(tb[0].name, "caller")
+                self.assertEqual(tb[1].name, "concurrently_execute")
+                self.assertEqual(tb[-2].name, "callback")
+            else:
+                self.fail("No exception thrown")
+
+        d2 = ensureDeferred(caller())
+        d1.callback(0)
+        self.successResultOf(d2)