summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8690.misc1
-rw-r--r--tests/test_utils/__init__.py34
-rw-r--r--tests/unittest.py6
3 files changed, 39 insertions, 2 deletions
diff --git a/changelog.d/8690.misc b/changelog.d/8690.misc
new file mode 100644
index 0000000000..0f38ba1f5d
--- /dev/null
+++ b/changelog.d/8690.misc
@@ -0,0 +1 @@
+Fail tests if they do not await coroutines.
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index a298cc0fd3..d232b72264 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,8 +17,10 @@
 """
 Utilities for running the unit tests
 """
+import sys
+import warnings
 from asyncio import Future
-from typing import Any, Awaitable, TypeVar
+from typing import Any, Awaitable, Callable, TypeVar
 
 TV = TypeVar("TV")
 
@@ -48,3 +50,33 @@ def make_awaitable(result: Any) -> Awaitable[Any]:
     future = Future()  # type: ignore
     future.set_result(result)
     return future
+
+
+def setup_awaitable_errors() -> Callable[[], None]:
+    """
+    Convert warnings from a non-awaited coroutines into errors.
+    """
+    warnings.simplefilter("error", RuntimeWarning)
+
+    # unraisablehook was added in Python 3.8.
+    if not hasattr(sys, "unraisablehook"):
+        return lambda: None
+
+    # State shared between unraisablehook and check_for_unraisable_exceptions.
+    unraisable_exceptions = []
+    orig_unraisablehook = sys.unraisablehook  # type: ignore
+
+    def unraisablehook(unraisable):
+        unraisable_exceptions.append(unraisable.exc_value)
+
+    def cleanup():
+        """
+        A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
+        """
+        sys.unraisablehook = orig_unraisablehook  # type: ignore
+        if unraisable_exceptions:
+            raise unraisable_exceptions.pop()
+
+    sys.unraisablehook = unraisablehook  # type: ignore
+
+    return cleanup
diff --git a/tests/unittest.py b/tests/unittest.py
index 257f465897..08cf9b10c5 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -54,7 +54,7 @@ from tests.server import (
     render,
     setup_test_homeserver,
 )
-from tests.test_utils import event_injection
+from tests.test_utils import event_injection, setup_awaitable_errors
 from tests.test_utils.logging_setup import setup_logging
 from tests.utils import default_config, setupdb
 
@@ -119,6 +119,10 @@ class TestCase(unittest.TestCase):
 
                 logging.getLogger().setLevel(level)
 
+            # Trial messes with the warnings configuration, thus this has to be
+            # done in the context of an individual TestCase.
+            self.addCleanup(setup_awaitable_errors())
+
             return orig()
 
         @around(self)