diff --git a/tests/unittest.py b/tests/unittest.py
index 4aa7f56106..24077d79d6 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -40,6 +40,7 @@ from typing import (
Mapping,
NoReturn,
Optional,
+ Protocol,
Tuple,
Type,
TypeVar,
@@ -50,7 +51,7 @@ from unittest.mock import Mock, patch
import canonicaljson
import signedjson.key
import unpaddedbase64
-from typing_extensions import Concatenate, ParamSpec, Protocol
+from typing_extensions import Concatenate, ParamSpec
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
@@ -272,8 +273,8 @@ class TestCase(unittest.TestCase):
def assertIncludes(
self,
- actual_items: AbstractSet[str],
- expected_items: AbstractSet[str],
+ actual_items: AbstractSet[TV],
+ expected_items: AbstractSet[TV],
exact: bool = False,
message: Optional[str] = None,
) -> None:
@@ -457,7 +458,9 @@ class HomeserverTestCase(TestCase):
# Type ignore: mypy doesn't like us assigning to methods.
self.hs.get_auth().get_user_by_req = get_requester # type: ignore[method-assign]
self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[method-assign]
- self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[method-assign]
+ self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[method-assign]
+ return_value=token
+ )
if self.needs_threadpool:
self.reactor.threadpool = ThreadPool() # type: ignore[assignment]
@@ -779,7 +782,7 @@ class HomeserverTestCase(TestCase):
self,
username: str,
appservice_token: str,
- ) -> Tuple[str, str]:
+ ) -> Tuple[str, Optional[str]]:
"""Register an appservice user as an application service.
Requires the client-facing registration API be registered.
@@ -803,7 +806,7 @@ class HomeserverTestCase(TestCase):
access_token=appservice_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
- return channel.json_body["user_id"], channel.json_body["device_id"]
+ return channel.json_body["user_id"], channel.json_body.get("device_id")
def login(
self,
|