diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index e62ebcc6a5..e5dae670a7 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -20,12 +20,13 @@ import sys
import warnings
from asyncio import Future
from binascii import unhexlify
-from typing import Awaitable, Callable, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
from unittest.mock import Mock
import attr
import zope.interface
+from twisted.internet.interfaces import IProtocol
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from twisted.web.http import RESPONSES
@@ -34,6 +35,9 @@ from twisted.web.iweb import IResponse
from synapse.types import JsonDict
+if TYPE_CHECKING:
+ from sys import UnraisableHookArgs
+
TV = TypeVar("TV")
@@ -78,25 +82,29 @@ def setup_awaitable_errors() -> Callable[[], None]:
unraisable_exceptions = []
orig_unraisablehook = sys.unraisablehook
- def unraisablehook(unraisable):
+ def unraisablehook(unraisable: "UnraisableHookArgs") -> None:
unraisable_exceptions.append(unraisable.exc_value)
- def cleanup():
+ def cleanup() -> None:
"""
A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
"""
sys.unraisablehook = orig_unraisablehook
if unraisable_exceptions:
- raise unraisable_exceptions.pop()
+ exc = unraisable_exceptions.pop()
+ assert exc is not None
+ raise exc
sys.unraisablehook = unraisablehook
return cleanup
-def simple_async_mock(return_value=None, raises=None) -> Mock:
+def simple_async_mock(
+ return_value: Optional[TV] = None, raises: Optional[Exception] = None
+) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour
- async def cb(*args, **kwargs):
+ async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
if raises:
raise raises
return return_value
@@ -125,14 +133,14 @@ class FakeResponse: # type: ignore[misc]
headers: Headers = attr.Factory(Headers)
@property
- def phrase(self):
+ def phrase(self) -> bytes:
return RESPONSES.get(self.code, b"Unknown Status")
@property
- def length(self):
+ def length(self) -> int:
return len(self.body)
- def deliverBody(self, protocol):
+ def deliverBody(self, protocol: IProtocol) -> None:
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))
|