summary refs log tree commit diff
path: root/tests/test_utils/__init__.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-10-30 07:15:07 -0400
committerGitHub <noreply@github.com>2020-10-30 07:15:07 -0400
commitfd7c7434457e215d73873748604f430c52586498 (patch)
tree802dadad7c09fa86789b7392de90745d45da4186 /tests/test_utils/__init__.py
parentFix race for concurrent downloads of remote media. (#8682) (diff)
downloadsynapse-fd7c7434457e215d73873748604f430c52586498.tar.xz
Fail test cases if they fail to await all awaitables (#8690)
Diffstat (limited to 'tests/test_utils/__init__.py')
-rw-r--r--tests/test_utils/__init__.py34
1 files changed, 33 insertions, 1 deletions
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index a298cc0fd3..d232b72264 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,8 +17,10 @@
 """
 Utilities for running the unit tests
 """
+import sys
+import warnings
 from asyncio import Future
-from typing import Any, Awaitable, TypeVar
+from typing import Any, Awaitable, Callable, TypeVar
 
 TV = TypeVar("TV")
 
@@ -48,3 +50,33 @@ def make_awaitable(result: Any) -> Awaitable[Any]:
     future = Future()  # type: ignore
     future.set_result(result)
     return future
+
+
+def setup_awaitable_errors() -> Callable[[], None]:
+    """
+    Convert warnings from a non-awaited coroutines into errors.
+    """
+    warnings.simplefilter("error", RuntimeWarning)
+
+    # unraisablehook was added in Python 3.8.
+    if not hasattr(sys, "unraisablehook"):
+        return lambda: None
+
+    # State shared between unraisablehook and check_for_unraisable_exceptions.
+    unraisable_exceptions = []
+    orig_unraisablehook = sys.unraisablehook  # type: ignore
+
+    def unraisablehook(unraisable):
+        unraisable_exceptions.append(unraisable.exc_value)
+
+    def cleanup():
+        """
+        A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
+        """
+        sys.unraisablehook = orig_unraisablehook  # type: ignore
+        if unraisable_exceptions:
+            raise unraisable_exceptions.pop()
+
+    sys.unraisablehook = unraisablehook  # type: ignore
+
+    return cleanup