diff --git a/changelog.d/12109.misc b/changelog.d/12109.misc
new file mode 100644
index 0000000000..3295e49f43
--- /dev/null
+++ b/changelog.d/12109.misc
@@ -0,0 +1 @@
+Improve exception handling for concurrent execution.
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index a9c964cd75..ce1fa3c78e 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -55,8 +55,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
-from synapse.util import json_decoder, json_encoder, log_failure
-from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
+from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
+from synapse.util.async_helpers import Linearizer, gather_results
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 511f52534b..58b4220ff3 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -81,7 +81,9 @@ json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
def unwrapFirstError(failure: Failure) -> Failure:
- # defer.gatherResults and DeferredLists wrap failures.
+ # Unwraps the FirstError in which defer.gatherResults/DeferredList wrap failures.
+ # Deprecated: prefer catching defer.FirstError and reraising its subFailure's
+ # value, which preserves stacktraces better (or use yieldable_gather_results).
failure.trap(defer.FirstError)
return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 3f7299aff7..a83296a229 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -29,6 +29,7 @@ from typing import (
Hashable,
Iterable,
Iterator,
+ List,
Optional,
Set,
Tuple,
@@ -51,7 +52,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
run_in_background,
)
-from synapse.util import Clock, unwrapFirstError
+from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -193,9 +194,9 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
T = TypeVar("T")
-def concurrently_execute(
+async def concurrently_execute(
func: Callable[[T], Any], args: Iterable[T], limit: int
-) -> defer.Deferred:
+) -> None:
"""Executes the function with each argument concurrently while limiting
the number of concurrent executions.
@@ -221,20 +222,14 @@ def concurrently_execute(
# We use `itertools.islice` to handle the case where the number of args is
# less than the limit, avoiding needlessly spawning unnecessary background
# tasks.
- return make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(_concurrently_execute_inner, value)
- for value in itertools.islice(it, limit)
- ],
- consumeErrors=True,
- )
- ).addErrback(unwrapFirstError)
+ await yieldable_gather_results(
+ _concurrently_execute_inner, (value for value in itertools.islice(it, limit))
+ )
-def yieldable_gather_results(
- func: Callable, iter: Iterable, *args: Any, **kwargs: Any
-) -> defer.Deferred:
+async def yieldable_gather_results(
+ func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any
+) -> List[T]:
"""Executes the function with each argument concurrently.
Args:
@@ -245,15 +240,30 @@ def yieldable_gather_results(
**kwargs: Keyword arguments to be passed to each call to func
Returns
- Deferred[list]: Resolved when all functions have been invoked, or errors if
- one of the function calls fails.
+ A list containing the result of each call to func.
"""
- return make_deferred_yieldable(
- defer.gatherResults(
- [run_in_background(func, item, *args, **kwargs) for item in iter],
- consumeErrors=True,
+ try:
+ return await make_deferred_yieldable(
+ defer.gatherResults(
+ [run_in_background(func, item, *args, **kwargs) for item in iter],
+ consumeErrors=True,
+ )
)
- ).addErrback(unwrapFirstError)
+ except defer.FirstError as dfe:
+ # unwrap the error from defer.gatherResults.
+
+ # The raised exception's traceback only includes func() etc. if
+ # the 'await' happens before the exception is thrown - i.e. if the failure
+ # happens *asynchronously* - otherwise Twisted throws away the traceback as it
+ # could be large.
+ #
+ # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
+ # we could throw Twisted into the fires of Mordor.
+
+ # suppress exception chaining, because the FirstError doesn't tell us anything
+ # very interesting.
+ assert isinstance(dfe.subFailure.value, BaseException)
+ raise dfe.subFailure.value from None
T1 = TypeVar("T1")
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index ab89cab812..cce8d595fc 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -11,9 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import traceback
+
from twisted.internet import defer
-from twisted.internet.defer import CancelledError, Deferred
+from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
from twisted.internet.task import Clock
+from twisted.python.failure import Failure
from synapse.logging.context import (
SENTINEL_CONTEXT,
@@ -21,7 +24,11 @@ from synapse.logging.context import (
PreserveLoggingContext,
current_context,
)
-from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
+from synapse.util.async_helpers import (
+ ObservableDeferred,
+ concurrently_execute,
+ timeout_deferred,
+)
from tests.unittest import TestCase
@@ -171,3 +178,107 @@ class TimeoutDeferredTest(TestCase):
)
self.failureResultOf(timing_out_d, defer.TimeoutError)
self.assertIs(current_context(), context_one)
+
+
+class _TestException(Exception):
+ pass
+
+
+class ConcurrentlyExecuteTest(TestCase):
+ def test_limits_runners(self):
+ """If we have more tasks than runners, we should get the limit of runners"""
+ started = 0
+ waiters = []
+ processed = []
+
+ async def callback(v):
+ # when we first enter, bump the start count
+ nonlocal started
+ started += 1
+
+ # record the fact we got an item
+ processed.append(v)
+
+ # wait for the goahead before returning
+ d2 = Deferred()
+ waiters.append(d2)
+ await d2
+
+ # set it going
+ d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3))
+
+ # check that exactly 3 callbacks were started (the concurrency limit)
+ self.assertEqual(started, 3)
+ self.assertEqual(len(waiters), 3)
+
+ # let one finish
+ waiters.pop().callback(0)
+
+ # ... which should start another
+ self.assertEqual(started, 4)
+ self.assertEqual(len(waiters), 3)
+
+ # we still shouldn't be done
+ self.assertNoResult(d2)
+
+ # finish the job
+ while waiters:
+ waiters.pop().callback(0)
+
+ # check everything got done
+ self.assertEqual(started, 5)
+ self.assertCountEqual(processed, [1, 2, 3, 4, 5])
+ self.successResultOf(d2)
+
+ def test_preserves_stacktraces(self):
+ """Test that the stacktrace from an exception thrown in the callback is preserved"""
+ d1 = Deferred()
+
+ async def callback(v):
+ # alas, the stacktrace isn't preserved at all unless we await before failing
+ await d1
+ raise _TestException("bah")
+
+ async def caller():
+ try:
+ await concurrently_execute(callback, [1], 2)
+ except _TestException as e:
+ tb = traceback.extract_tb(e.__traceback__)
+ # we expect to see "caller", "concurrently_execute" and "callback".
+ self.assertEqual(tb[0].name, "caller")
+ self.assertEqual(tb[1].name, "concurrently_execute")
+ self.assertEqual(tb[-1].name, "callback")
+ else:
+ self.fail("No exception thrown")
+
+ d2 = ensureDeferred(caller())
+ d1.callback(0)
+ self.successResultOf(d2)
+
+ def test_preserves_stacktraces_on_preformed_failure(self):
+ """Test that the stacktrace on a Failure returned by the callback is preserved"""
+ d1 = Deferred()
+ f = Failure(_TestException("bah"))
+
+ async def callback(v):
+ # alas, the stacktrace isn't preserved at all unless we await before failing
+ await d1
+ await defer.fail(f)
+
+ async def caller():
+ try:
+ await concurrently_execute(callback, [1], 2)
+ except _TestException as e:
+ tb = traceback.extract_tb(e.__traceback__)
+ # we expect to see "caller", "concurrently_execute", "callback",
+ # and some magic from inside ensureDeferred that happens when .fail
+ # is called.
+ self.assertEqual(tb[0].name, "caller")
+ self.assertEqual(tb[1].name, "concurrently_execute")
+ self.assertEqual(tb[-2].name, "callback")
+ else:
+ self.fail("No exception thrown")
+
+ d2 = ensureDeferred(caller())
+ d1.callback(0)
+ self.successResultOf(d2)
|