summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/unittest.py15
1 files changed, 9 insertions, 6 deletions
diff --git a/tests/unittest.py b/tests/unittest.py

index 4aa7f56106..24077d79d6 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -40,6 +40,7 @@ from typing import ( Mapping, NoReturn, Optional, + Protocol, Tuple, Type, TypeVar, @@ -50,7 +51,7 @@ from unittest.mock import Mock, patch import canonicaljson import signedjson.key import unpaddedbase64 -from typing_extensions import Concatenate, ParamSpec, Protocol +from typing_extensions import Concatenate, ParamSpec from twisted.internet.defer import Deferred, ensureDeferred from twisted.python.failure import Failure @@ -272,8 +273,8 @@ class TestCase(unittest.TestCase): def assertIncludes( self, - actual_items: AbstractSet[str], - expected_items: AbstractSet[str], + actual_items: AbstractSet[TV], + expected_items: AbstractSet[TV], exact: bool = False, message: Optional[str] = None, ) -> None: @@ -457,7 +458,9 @@ class HomeserverTestCase(TestCase): # Type ignore: mypy doesn't like us assigning to methods. self.hs.get_auth().get_user_by_req = get_requester # type: ignore[method-assign] self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[method-assign] - self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[method-assign] + self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[method-assign] + return_value=token + ) if self.needs_threadpool: self.reactor.threadpool = ThreadPool() # type: ignore[assignment] @@ -779,7 +782,7 @@ class HomeserverTestCase(TestCase): self, username: str, appservice_token: str, - ) -> Tuple[str, str]: + ) -> Tuple[str, Optional[str]]: """Register an appservice user as an application service. Requires the client-facing registration API be registered. @@ -803,7 +806,7 @@ class HomeserverTestCase(TestCase): access_token=appservice_token, ) self.assertEqual(channel.code, 200, channel.json_body) - return channel.json_body["user_id"], channel.json_body["device_id"] + return channel.json_body["user_id"], channel.json_body.get("device_id") def login( self,