diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index e62ebcc6a5..e5dae670a7 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -20,12 +20,13 @@ import sys
import warnings
from asyncio import Future
from binascii import unhexlify
-from typing import Awaitable, Callable, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
from unittest.mock import Mock
import attr
import zope.interface
+from twisted.internet.interfaces import IProtocol
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from twisted.web.http import RESPONSES
@@ -34,6 +35,9 @@ from twisted.web.iweb import IResponse
from synapse.types import JsonDict
+if TYPE_CHECKING:
+ from sys import UnraisableHookArgs
+
TV = TypeVar("TV")
@@ -78,25 +82,29 @@ def setup_awaitable_errors() -> Callable[[], None]:
unraisable_exceptions = []
orig_unraisablehook = sys.unraisablehook
- def unraisablehook(unraisable):
+ def unraisablehook(unraisable: "UnraisableHookArgs") -> None:
unraisable_exceptions.append(unraisable.exc_value)
- def cleanup():
+ def cleanup() -> None:
"""
A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
"""
sys.unraisablehook = orig_unraisablehook
if unraisable_exceptions:
- raise unraisable_exceptions.pop()
+ exc = unraisable_exceptions.pop()
+ assert exc is not None
+ raise exc
sys.unraisablehook = unraisablehook
return cleanup
-def simple_async_mock(return_value=None, raises=None) -> Mock:
+def simple_async_mock(
+ return_value: Optional[TV] = None, raises: Optional[Exception] = None
+) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour
- async def cb(*args, **kwargs):
+ async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
if raises:
raise raises
return return_value
@@ -125,14 +133,14 @@ class FakeResponse: # type: ignore[misc]
headers: Headers = attr.Factory(Headers)
@property
- def phrase(self):
+ def phrase(self) -> bytes:
return RESPONSES.get(self.code, b"Unknown Status")
@property
- def length(self):
+ def length(self) -> int:
return len(self.body)
- def deliverBody(self, protocol):
+ def deliverBody(self, protocol: IProtocol) -> None:
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 8027c7a856..1a50c2acf1 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
import synapse.server
from synapse.api.constants import EventTypes
@@ -32,7 +32,7 @@ async def inject_member_event(
membership: str,
target: Optional[str] = None,
extra_content: Optional[dict] = None,
- **kwargs,
+ **kwargs: Any,
) -> EventBase:
"""Inject a membership event into a room."""
if target is None:
@@ -57,7 +57,7 @@ async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
- **kwargs,
+ **kwargs: Any,
) -> EventBase:
"""Inject a generic event into a room
@@ -82,7 +82,7 @@ async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
- **kwargs,
+ **kwargs: Any,
) -> Tuple[EventBase, EventContext]:
if room_version is None:
room_version = await hs.get_datastores().main.get_room_version_id(
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
index e878af5f12..189c697efb 100644
--- a/tests/test_utils/html_parsers.py
+++ b/tests/test_utils/html_parsers.py
@@ -13,13 +13,13 @@
# limitations under the License.
from html.parser import HTMLParser
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, NoReturn, Optional, Tuple
class TestHtmlParser(HTMLParser):
"""A generic HTML page parser which extracts useful things from the HTML"""
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
# a list of links found in the doc
@@ -48,5 +48,5 @@ class TestHtmlParser(HTMLParser):
assert input_name
self.hiddens[input_name] = attr_dict["value"]
- def error(_, message):
+ def error(self, message: str) -> NoReturn:
raise AssertionError(message)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 304c7b98c5..b522163a34 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -25,7 +25,7 @@ class ToTwistedHandler(logging.Handler):
tx_log = twisted.logger.Logger()
- def emit(self, record):
+ def emit(self, record: logging.LogRecord) -> None:
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
self.tx_log.emit(
@@ -33,7 +33,7 @@ class ToTwistedHandler(logging.Handler):
)
-def setup_logging():
+def setup_logging() -> None:
"""Configure the python logging appropriately for the tests.
(Logs will end up in _trial_temp.)
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index 1461d23ee8..d555b24255 100644
--- a/tests/test_utils/oidc.py
+++ b/tests/test_utils/oidc.py
@@ -14,7 +14,7 @@
import json
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, ContextManager, Dict, List, Optional, Tuple
from unittest.mock import Mock, patch
from urllib.parse import parse_qs
@@ -77,14 +77,14 @@ class FakeOidcServer:
self._id_token_overrides: Dict[str, Any] = {}
- def reset_mocks(self):
+ def reset_mocks(self) -> None:
self.request.reset_mock()
self.get_jwks_handler.reset_mock()
self.get_metadata_handler.reset_mock()
self.get_userinfo_handler.reset_mock()
self.post_token_handler.reset_mock()
- def patch_homeserver(self, hs: HomeServer):
+ def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]:
"""Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
This patch should be used whenever the HS is expected to perform request to the
@@ -188,7 +188,7 @@ class FakeOidcServer:
return self._sign(logout_token)
- def id_token_override(self, overrides: dict):
+ def id_token_override(self, overrides: dict) -> ContextManager[dict]:
"""Temporarily patch the ID token generated by the token endpoint."""
return patch.object(self, "_id_token_overrides", overrides)
@@ -247,7 +247,7 @@ class FakeOidcServer:
metadata: bool = False,
token: bool = False,
userinfo: bool = False,
- ):
+ ) -> ContextManager[Dict[str, Mock]]:
"""A context which makes a set of endpoints return a 500 error.
Args:
|