Fail test cases if they fail to await all awaitables (#8690)
1 files changed, 33 insertions, 1 deletions
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index a298cc0fd3..d232b72264 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,8 +17,10 @@
"""
Utilities for running the unit tests
"""
+import sys
+import warnings
from asyncio import Future
-from typing import Any, Awaitable, TypeVar
+from typing import Any, Awaitable, Callable, TypeVar
TV = TypeVar("TV")
@@ -48,3 +50,33 @@ def make_awaitable(result: Any) -> Awaitable[Any]:
future = Future() # type: ignore
future.set_result(result)
return future
+
+
+def setup_awaitable_errors() -> Callable[[], None]:
+ """
+ Convert warnings from a non-awaited coroutines into errors.
+ """
+ warnings.simplefilter("error", RuntimeWarning)
+
+ # unraisablehook was added in Python 3.8.
+ if not hasattr(sys, "unraisablehook"):
+ return lambda: None
+
+ # State shared between unraisablehook and check_for_unraisable_exceptions.
+ unraisable_exceptions = []
+ orig_unraisablehook = sys.unraisablehook # type: ignore
+
+ def unraisablehook(unraisable):
+ unraisable_exceptions.append(unraisable.exc_value)
+
+ def cleanup():
+ """
+ A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
+ """
+ sys.unraisablehook = orig_unraisablehook # type: ignore
+ if unraisable_exceptions:
+ raise unraisable_exceptions.pop()
+
+ sys.unraisablehook = unraisablehook # type: ignore
+
+ return cleanup
|