diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 53d49ca896..3b72c4c9d0 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -481,17 +481,13 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return config
- def prepare(
- self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
- ) -> HomeServer:
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
self.denied_user_id = self.register_user("denied", "pass")
self.denied_access_token = self.login("denied", "pass")
- return hs
-
def test_denied_without_publication_permission(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
@@ -575,9 +571,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets]
- def prepare(
- self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
- ) -> HomeServer:
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -588,8 +582,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.room_list_handler = hs.get_room_list_handler()
self.directory_handler = hs.get_directory_handler()
- return hs
-
def test_disabling_room_list(self) -> None:
self.room_list_handler.enable_room_list_search = True
self.directory_handler.enable_room_list_search = True
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 8a0bb91f40..745750b1d7 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -14,6 +14,7 @@
import logging
from typing import cast
from unittest import TestCase
+from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor
@@ -22,6 +23,7 @@ from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseErro
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_base import event_from_pdu_json
+from synapse.federation.federation_client import SendJoinResult
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -30,7 +32,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
-from tests.test_utils import event_injection
+from tests.test_utils import event_injection, make_awaitable
logger = logging.getLogger(__name__)
@@ -280,13 +282,21 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# we poke this directly into _process_received_pdu, to avoid the
# federation handler wanting to backfill the fake event.
+ state_handler = self.hs.get_state_handler()
+ context = self.get_success(
+ state_handler.compute_event_context(
+ event,
+ state_ids_before_event={
+ (e.type, e.state_key): e.event_id for e in current_state
+ },
+ partial_state=False,
+ )
+ )
self.get_success(
federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME,
event,
- state_ids={
- (e.type, e.state_key): e.event_id for e in current_state
- },
+ context,
)
)
@@ -448,3 +458,121 @@ class EventFromPduTestCase(TestCase):
},
RoomVersions.V6,
)
+
+
+class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
+ def test_failed_partial_join_is_clean(self) -> None:
+ """
+ Tests that, when failing to partial-join a room, we don't get stuck with
+ a partial-state flag on a room.
+ """
+
+ fed_handler = self.hs.get_federation_handler()
+ fed_client = fed_handler.federation_client
+
+ room_id = "!room:example.com"
+ membership_event = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "sender": "@alice:test",
+ "state_key": "@alice:test",
+ "content": {"membership": "join"},
+ },
+ RoomVersions.V10,
+ )
+
+ mock_make_membership_event = Mock(
+ return_value=make_awaitable(
+ (
+ "example.com",
+ membership_event,
+ RoomVersions.V10,
+ )
+ )
+ )
+
+ EVENT_CREATE = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.create",
+ "sender": "@kristina:example.com",
+ "state_key": "",
+ "depth": 0,
+ "content": {"creator": "@kristina:example.com", "room_version": "10"},
+ "auth_events": [],
+ "origin_server_ts": 1,
+ },
+ room_version=RoomVersions.V10,
+ )
+ EVENT_CREATOR_MEMBERSHIP = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "sender": "@kristina:example.com",
+ "state_key": "@kristina:example.com",
+ "content": {"membership": "join"},
+ "depth": 1,
+ "prev_events": [EVENT_CREATE.event_id],
+ "auth_events": [EVENT_CREATE.event_id],
+ "origin_server_ts": 1,
+ },
+ room_version=RoomVersions.V10,
+ )
+ EVENT_INVITATION_MEMBERSHIP = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "sender": "@kristina:example.com",
+ "state_key": "@alice:test",
+ "content": {"membership": "invite"},
+ "depth": 2,
+ "prev_events": [EVENT_CREATOR_MEMBERSHIP.event_id],
+ "auth_events": [
+ EVENT_CREATE.event_id,
+ EVENT_CREATOR_MEMBERSHIP.event_id,
+ ],
+ "origin_server_ts": 1,
+ },
+ room_version=RoomVersions.V10,
+ )
+ mock_send_join = Mock(
+ return_value=make_awaitable(
+ SendJoinResult(
+ membership_event,
+ "example.com",
+ state=[
+ EVENT_CREATE,
+ EVENT_CREATOR_MEMBERSHIP,
+ EVENT_INVITATION_MEMBERSHIP,
+ ],
+ auth_chain=[
+ EVENT_CREATE,
+ EVENT_CREATOR_MEMBERSHIP,
+ EVENT_INVITATION_MEMBERSHIP,
+ ],
+ partial_state=True,
+ servers_in_room=["example.com"],
+ )
+ )
+ )
+
+ with patch.object(
+ fed_client, "make_membership_event", mock_make_membership_event
+ ), patch.object(fed_client, "send_join", mock_send_join):
+ # Join and check that our join event is rejected
+ # (The join event is rejected because it doesn't have any signatures)
+ join_exc = self.get_failure(
+ fed_handler.do_invite_join(["example.com"], room_id, "@alice:test", {}),
+ SynapseError,
+ )
+ self.assertIn("Join event was rejected", str(join_exc))
+
+ store = self.hs.get_datastores().main
+
+ # Check that we don't have a left-over partial_state entry.
+ self.assertFalse(
+ self.get_success(store.is_partial_state_room(room_id)),
+ f"Stale partial-stated room flag left over for {room_id} after a"
+ f" failed do_invite_join!",
+ )
diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py
index 6f77b1237c..da4bf8b582 100644
--- a/tests/handlers/test_send_email.py
+++ b/tests/handlers/test_send_email.py
@@ -23,7 +23,7 @@ from twisted.internet.defer import ensureDeferred
from twisted.mail import interfaces, smtp
from tests.server import FakeTransport
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
@implementer(interfaces.IMessageDelivery)
@@ -110,3 +110,58 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
user, msg = message_delivery.messages.pop()
self.assertEqual(str(user), "foo@bar.com")
self.assertIn(b"Subject: test subject", msg)
+
+ @override_config(
+ {
+ "email": {
+ "notif_from": "noreply@test",
+ "force_tls": True,
+ },
+ }
+ )
+ def test_send_email_force_tls(self):
+ """Happy-path test that we can send email to an Implicit TLS server."""
+ h = self.hs.get_send_email_handler()
+ d = ensureDeferred(
+ h.send_email(
+ "foo@bar.com", "test subject", "Tests", "HTML content", "Text content"
+ )
+ )
+ # there should be an attempt to connect to localhost:465
+ self.assertEqual(len(self.reactor.sslClients), 1)
+ (
+ host,
+ port,
+ client_factory,
+ contextFactory,
+ _timeout,
+ _bindAddress,
+ ) = self.reactor.sslClients[0]
+ self.assertEqual(host, "localhost")
+ self.assertEqual(port, 465)
+
+ # wire it up to an SMTP server
+ message_delivery = _DummyMessageDelivery()
+ server_protocol = smtp.ESMTP()
+ server_protocol.delivery = message_delivery
+ # make sure that the server uses the test reactor to set timeouts
+ server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]
+
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
+ server_protocol.makeConnection(
+ FakeTransport(
+ client_protocol,
+ self.reactor,
+ peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+ )
+ )
+
+ # the message should now get delivered
+ self.get_success(d, by=0.1)
+
+ # check it arrived
+ self.assertEqual(len(message_delivery.messages), 1)
+ user, msg = message_delivery.messages.pop()
+ self.assertEqual(str(user), "foo@bar.com")
+ self.assertIn(b"Subject: test subject", msg)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 2526136ff8..623883b53c 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1873,7 +1873,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
- self.assertEqual("No known servers", channel.json_body["error"])
+ self.assertEqual(
+ "Can't join remote room because no servers that are in the room have been provided.",
+ channel.json_body["error"],
+ )
def test_room_is_not_valid(self) -> None:
"""
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 071b488cc0..f8e64ce6ac 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -586,9 +586,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"require_at_registration": True,
},
"account_threepid_delegates": {
+ "email": "https://id_server",
"msisdn": "https://id_server",
},
- "email": {"notif_from": "Synapse <synapse@example.com>"},
}
)
def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index ad03eee17b..d589f07314 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1060,6 +1060,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
participated, bundled_aggregations.get("current_user_participated")
)
# The latest thread event has some fields that don't matter.
+ self.assertIn("latest_event", bundled_aggregations)
self.assert_dict(
{
"content": {
@@ -1072,7 +1073,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"sender": self.user2_id,
"type": "m.room.test",
},
- bundled_aggregations.get("latest_event"),
+ bundled_aggregations["latest_event"],
)
return assert_thread
@@ -1112,6 +1113,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
self.assertEqual(2, bundled_aggregations.get("count"))
self.assertTrue(bundled_aggregations.get("current_user_participated"))
# The latest thread event has some fields that don't matter.
+ self.assertIn("latest_event", bundled_aggregations)
self.assert_dict(
{
"content": {
@@ -1124,7 +1126,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"sender": self.user_id,
"type": "m.room.test",
},
- bundled_aggregations.get("latest_event"),
+ bundled_aggregations["latest_event"],
)
# Check the unsigned field on the latest event.
self.assert_dict(
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index c45cb32090..aa2f578441 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -496,7 +496,7 @@ class RoomStateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertCountEqual(
- [state_event["type"] for state_event in channel.json_body],
+ [state_event["type"] for state_event in channel.json_list],
{
"m.room.create",
"m.room.power_levels",
@@ -2070,7 +2070,6 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
config = self.default_config()
config["allow_public_rooms_without_auth"] = True
- config["experimental_features"] = {"msc3827_enabled": True}
self.hs = self.setup_test_homeserver(config=config)
self.url = b"/_matrix/client/r0/publicRooms"
@@ -2123,13 +2122,13 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
chunk, count = self.make_public_rooms_request([None])
self.assertEqual(count, 1)
- self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), None)
+ self.assertEqual(chunk[0].get("room_type", None), None)
def test_returns_only_space_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request(["m.space"])
self.assertEqual(count, 1)
- self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), "m.space")
+ self.assertEqual(chunk[0].get("room_type", None), "m.space")
def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request(["m.space", None])
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 9a48e9286f..18a7195409 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -20,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, LoginType, Membership
from synapse.api.errors import SynapseError
from synapse.api.room_versions import RoomVersion
+from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.rest import admin
@@ -185,12 +186,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
class NastyHackException(SynapseError):
- def error_dict(self) -> JsonDict:
+ def error_dict(self, config: Optional[HomeServerConfig]) -> JsonDict:
"""
This overrides SynapseError's `error_dict` to nastily inject
JSON into the error response.
"""
- result = super().error_dict()
+ result = super().error_dict(config)
result["nasty"] = "very"
return result
diff --git a/tests/server.py b/tests/server.py
index df3f1564c9..9689e6a0cd 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -25,6 +25,7 @@ from typing import (
Callable,
Dict,
Iterable,
+ List,
MutableMapping,
Optional,
Tuple,
@@ -121,7 +122,15 @@ class FakeChannel:
@property
def json_body(self) -> JsonDict:
- return json.loads(self.text_body)
+ body = json.loads(self.text_body)
+ assert isinstance(body, dict)
+ return body
+
+ @property
+ def json_list(self) -> List[JsonDict]:
+ body = json.loads(self.text_body)
+ assert isinstance(body, list)
+ return body
@property
def text_body(self) -> str:
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 2ff88e64a5..3ce4f35cb7 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -70,7 +70,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
def persist_event(self, event, state=None):
"""Persist the event, with optional state"""
context = self.get_success(
- self.state.compute_event_context(event, state_ids_before_event=state)
+ self.state.compute_event_context(
+ event,
+ state_ids_before_event=state,
+ partial_state=None if state is None else False,
+ )
)
self.get_success(self._persistence.persist_event(event, context))
@@ -148,6 +152,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
self.state.compute_event_context(
remote_event_2,
state_ids_before_event=state_before_gap,
+ partial_state=False,
)
)
diff --git a/tests/test_state.py b/tests/test_state.py
index bafd6d1750..504530b49a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -462,6 +462,7 @@ class StateTestCase(unittest.TestCase):
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
+ partial_state=False,
)
)
@@ -492,6 +493,7 @@ class StateTestCase(unittest.TestCase):
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
+ partial_state=False,
)
)
diff --git a/tests/unittest.py b/tests/unittest.py
index 66ce92f4a6..bec4a3d023 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -28,6 +28,7 @@ from typing import (
Generic,
Iterable,
List,
+ NoReturn,
Optional,
Tuple,
Type,
@@ -39,7 +40,7 @@ from unittest.mock import Mock, patch
import canonicaljson
import signedjson.key
import unpaddedbase64
-from typing_extensions import Protocol
+from typing_extensions import Concatenate, ParamSpec, Protocol
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
@@ -67,7 +68,7 @@ from synapse.logging.context import (
from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
@@ -88,6 +89,10 @@ setup_logging()
TV = TypeVar("TV")
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
+P = ParamSpec("P")
+R = TypeVar("R")
+S = TypeVar("S")
+
class _TypedFailure(Generic[_ExcType], Protocol):
"""Extension to twisted.Failure, where the 'value' has a certain type."""
@@ -97,7 +102,7 @@ class _TypedFailure(Generic[_ExcType], Protocol):
...
-def around(target):
+def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
"""A CLOS-style 'around' modifier, which wraps the original method of the
given instance with another piece of code.
@@ -106,11 +111,11 @@ def around(target):
return orig(*args, **kwargs)
"""
- def _around(code):
+ def _around(code: Callable[Concatenate[S, P], R]) -> None:
name = code.__name__
orig = getattr(target, name)
- def new(*args, **kwargs):
+ def new(*args: P.args, **kwargs: P.kwargs) -> R:
return code(orig, *args, **kwargs)
setattr(target, name, new)
@@ -131,7 +136,7 @@ class TestCase(unittest.TestCase):
level = getattr(method, "loglevel", getattr(self, "loglevel", None))
@around(self)
- def setUp(orig):
+ def setUp(orig: Callable[[], R]) -> R:
# if we're not starting in the sentinel logcontext, then to be honest
# all future bets are off.
if current_context():
@@ -144,7 +149,7 @@ class TestCase(unittest.TestCase):
if level is not None and old_level != level:
@around(self)
- def tearDown(orig):
+ def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
logging.getLogger().setLevel(old_level)
return ret
@@ -158,7 +163,7 @@ class TestCase(unittest.TestCase):
return orig()
@around(self)
- def tearDown(orig):
+ def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
# force a GC to workaround problems with deferreds leaking logcontexts when
# they are GCed (see the logcontext docs)
@@ -167,7 +172,7 @@ class TestCase(unittest.TestCase):
return ret
- def assertObjectHasAttributes(self, attrs, obj):
+ def assertObjectHasAttributes(self, attrs: Dict[str, object], obj: object) -> None:
"""Asserts that the given object has each of the attributes given, and
that the value of each matches according to assertEqual."""
for key in attrs.keys():
@@ -178,12 +183,12 @@ class TestCase(unittest.TestCase):
except AssertionError as e:
raise (type(e))(f"Assert error for '.{key}':") from e
- def assert_dict(self, required, actual):
+ def assert_dict(self, required: dict, actual: dict) -> None:
"""Does a partial assert of a dict.
Args:
- required (dict): The keys and value which MUST be in 'actual'.
- actual (dict): The test result. Extra keys will not be checked.
+ required: The keys and value which MUST be in 'actual'.
+ actual: The test result. Extra keys will not be checked.
"""
for key in required:
self.assertEqual(
@@ -191,31 +196,31 @@ class TestCase(unittest.TestCase):
)
-def DEBUG(target):
+def DEBUG(target: TV) -> TV:
"""A decorator to set the .loglevel attribute to logging.DEBUG.
Can apply to either a TestCase or an individual test method."""
- target.loglevel = logging.DEBUG
+ target.loglevel = logging.DEBUG # type: ignore[attr-defined]
return target
-def INFO(target):
+def INFO(target: TV) -> TV:
"""A decorator to set the .loglevel attribute to logging.INFO.
Can apply to either a TestCase or an individual test method."""
- target.loglevel = logging.INFO
+ target.loglevel = logging.INFO # type: ignore[attr-defined]
return target
-def logcontext_clean(target):
+def logcontext_clean(target: TV) -> TV:
"""A decorator which marks the TestCase or method as 'logcontext_clean'
... ie, any logcontext errors should cause a test failure
"""
- def logcontext_error(msg):
+ def logcontext_error(msg: str) -> NoReturn:
raise AssertionError("logcontext error: %s" % (msg))
patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
- return patcher(target)
+ return patcher(target) # type: ignore[call-overload]
class HomeserverTestCase(TestCase):
@@ -255,7 +260,7 @@ class HomeserverTestCase(TestCase):
method = getattr(self, methodName)
self._extra_config = getattr(method, "_extra_config", None)
- def setUp(self):
+ def setUp(self) -> None:
"""
Set up the TestCase by calling the homeserver constructor, optionally
hijacking the authentication system to return a fixed user, and then
@@ -306,7 +311,9 @@ class HomeserverTestCase(TestCase):
)
)
- async def get_user_by_access_token(token=None, allow_guest=False):
+ async def get_user_by_access_token(
+ token: Optional[str] = None, allow_guest: bool = False
+ ) -> JsonDict:
assert self.helper.auth_user_id is not None
return {
"user": UserID.from_string(self.helper.auth_user_id),
@@ -314,7 +321,11 @@ class HomeserverTestCase(TestCase):
"is_guest": False,
}
- async def get_user_by_req(request, allow_guest=False):
+ async def get_user_by_req(
+ request: SynapseRequest,
+ allow_guest: bool = False,
+ allow_expired: bool = False,
+ ) -> Requester:
assert self.helper.auth_user_id is not None
return create_requester(
UserID.from_string(self.helper.auth_user_id),
@@ -339,11 +350,11 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "prepare"):
self.prepare(self.reactor, self.clock, self.hs)
- def tearDown(self):
+ def tearDown(self) -> None:
# Reset to not use frozen dicts.
events.USE_FROZEN_DICTS = False
- def wait_on_thread(self, deferred, timeout=10):
+ def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None:
"""
Wait until a Deferred is done, where it's waiting on a real thread.
"""
@@ -374,7 +385,7 @@ class HomeserverTestCase(TestCase):
clock (synapse.util.Clock): The Clock, associated with the reactor.
Returns:
- A homeserver (synapse.server.HomeServer) suitable for testing.
+ A homeserver suitable for testing.
Function to be overridden in subclasses.
"""
@@ -408,7 +419,7 @@ class HomeserverTestCase(TestCase):
"/_synapse/admin": servlet_resource,
}
- def default_config(self):
+ def default_config(self) -> JsonDict:
"""
Get a default HomeServer config dict.
"""
@@ -421,7 +432,9 @@ class HomeserverTestCase(TestCase):
return config
- def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
"""
Prepare for the test. This involves things like mocking out parts of
the homeserver, or building test data common across the whole test
@@ -519,7 +532,7 @@ class HomeserverTestCase(TestCase):
config_obj.parse_config_dict(config, "", "")
kwargs["config"] = config_obj
- async def run_bg_updates():
+ async def run_bg_updates() -> None:
with LoggingContext("run_bg_updates"):
self.get_success(stor.db_pool.updates.run_background_updates(False))
@@ -538,11 +551,7 @@ class HomeserverTestCase(TestCase):
"""
self.reactor.pump([by] * 100)
- def get_success(
- self,
- d: Awaitable[TV],
- by: float = 0.0,
- ) -> TV:
+ def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
self.pump(by=by)
return self.successResultOf(deferred)
@@ -755,7 +764,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
OTHER_SERVER_NAME = "other.example.com"
OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
# poke the other server's signing key into the key store, so that we don't
@@ -879,7 +888,7 @@ def _auth_header_for_request(
)
-def override_config(extra_config):
+def override_config(extra_config: JsonDict) -> Callable[[TV], TV]:
"""A decorator which can be applied to test functions to give additional HS config
For use
@@ -892,12 +901,13 @@ def override_config(extra_config):
...
Args:
- extra_config(dict): Additional config settings to be merged into the default
+ extra_config: Additional config settings to be merged into the default
config dict before instantiating the test homeserver.
"""
- def decorator(func):
- func._extra_config = extra_config
+ def decorator(func: TV) -> TV:
+ # This attribute is being defined.
+ func._extra_config = extra_config # type: ignore[attr-defined]
return func
return decorator
|