diff --git a/tests/utils.py b/tests/utils.py
index 757320ebee..9fd26ef348 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -21,7 +21,20 @@
import atexit
import os
-from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload
+import signal
+from types import FrameType, TracebackType
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ overload,
+)
import attr
from typing_extensions import Literal, ParamSpec
@@ -379,3 +392,30 @@ def checked_cast(type: Type[T], x: object) -> T:
"""
assert isinstance(x, type)
return x
+
+
+class TestTimeout(Exception):
+ pass
+
+
+class test_timeout:
+ def __init__(self, seconds: int, error_message: Optional[str] = None) -> None:
+ if error_message is None:
+ error_message = "test timed out after {}s.".format(seconds)
+ self.seconds = seconds
+ self.error_message = error_message
+
+ def handle_timeout(self, signum: int, frame: Optional[FrameType]) -> None:
+ raise TestTimeout(self.error_message)
+
+ def __enter__(self) -> None:
+ signal.signal(signal.SIGALRM, self.handle_timeout)
+ signal.alarm(self.seconds)
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ signal.alarm(0)
|