diff --git a/changelog.d/15027.misc b/changelog.d/15027.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/15027.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/mypy.ini b/mypy.ini
index 1bdeb18d94..70c106c668 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -69,27 +69,9 @@ disallow_untyped_defs = False
[mypy-tests.server_notices.test_resource_limits_server_notices]
disallow_untyped_defs = False
-[mypy-tests.test_distributor]
-disallow_untyped_defs = False
-
-[mypy-tests.test_event_auth]
-disallow_untyped_defs = False
-
[mypy-tests.test_federation]
disallow_untyped_defs = False
-[mypy-tests.test_mau]
-disallow_untyped_defs = False
-
-[mypy-tests.test_rust]
-disallow_untyped_defs = False
-
-[mypy-tests.test_test_utils]
-disallow_untyped_defs = False
-
-[mypy-tests.test_types]
-disallow_untyped_defs = False
-
[mypy-tests.test_utils.*]
disallow_untyped_defs = False
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 31546ea52b..a248f1d277 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -21,10 +21,10 @@ from . import unittest
class DistributorTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.dist = Distributor()
- def test_signal_dispatch(self):
+ def test_signal_dispatch(self) -> None:
self.dist.declare("alert")
observer = Mock()
@@ -33,7 +33,7 @@ class DistributorTestCase(unittest.TestCase):
self.dist.fire("alert", 1, 2, 3)
observer.assert_called_with(1, 2, 3)
- def test_signal_catch(self):
+ def test_signal_catch(self) -> None:
self.dist.declare("alarm")
observers = [Mock() for i in (1, 2)]
@@ -51,7 +51,7 @@ class DistributorTestCase(unittest.TestCase):
self.assertEqual(mock_logger.warning.call_count, 1)
self.assertIsInstance(mock_logger.warning.call_args[0][0], str)
- def test_signal_prereg(self):
+ def test_signal_prereg(self) -> None:
observer = Mock()
self.dist.observe("flare", observer)
@@ -60,8 +60,8 @@ class DistributorTestCase(unittest.TestCase):
observer.assert_called_with(4, 5)
- def test_signal_undeclared(self):
- def code():
+ def test_signal_undeclared(self) -> None:
+ def code() -> None:
self.dist.fire("notification")
self.assertRaises(KeyError, code)
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 0a7937f1cc..2860564afc 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -31,13 +31,13 @@ from tests.test_utils import get_awaitable_result
class _StubEventSourceStore:
"""A stub implementation of the EventSourceStore"""
- def __init__(self):
+ def __init__(self) -> None:
self._store: Dict[str, EventBase] = {}
- def add_event(self, event: EventBase):
+ def add_event(self, event: EventBase) -> None:
self._store[event.event_id] = event
- def add_events(self, events: Iterable[EventBase]):
+ def add_events(self, events: Iterable[EventBase]) -> None:
for event in events:
self._store[event.event_id] = event
@@ -59,7 +59,7 @@ class _StubEventSourceStore:
class EventAuthTestCase(unittest.TestCase):
- def test_rejected_auth_events(self):
+ def test_rejected_auth_events(self) -> None:
"""
Events that refer to rejected events in their auth events are rejected
"""
@@ -109,7 +109,7 @@ class EventAuthTestCase(unittest.TestCase):
)
)
- def test_create_event_with_prev_events(self):
+ def test_create_event_with_prev_events(self) -> None:
"""A create event with prev_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -150,7 +150,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event)
)
- def test_duplicate_auth_events(self):
+ def test_duplicate_auth_events(self) -> None:
"""Events with duplicate auth_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -196,7 +196,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event2)
)
- def test_unexpected_auth_events(self):
+ def test_unexpected_auth_events(self) -> None:
"""Events with excess auth_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -236,7 +236,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event)
)
- def test_random_users_cannot_send_state_before_first_pl(self):
+ def test_random_users_cannot_send_state_before_first_pl(self) -> None:
"""
Check that, before the first PL lands, the creator is the only user
that can send a state event.
@@ -263,7 +263,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events,
)
- def test_state_default_level(self):
+ def test_state_default_level(self) -> None:
"""
Check that users above the state_default level can send state and
those below cannot
@@ -298,7 +298,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events,
)
- def test_alias_event(self):
+ def test_alias_event(self) -> None:
"""Alias events have special behavior up through room version 6."""
creator = "@creator:example.com"
other = "@other:example.com"
@@ -333,7 +333,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events,
)
- def test_msc2432_alias_event(self):
+ def test_msc2432_alias_event(self) -> None:
"""After MSC2432, alias events have no special behavior."""
creator = "@creator:example.com"
other = "@other:example.com"
@@ -366,7 +366,9 @@ class EventAuthTestCase(unittest.TestCase):
)
@parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)])
- def test_notifications(self, room_version: RoomVersion, allow_modification: bool):
+ def test_notifications(
+ self, room_version: RoomVersion, allow_modification: bool
+ ) -> None:
"""
Notifications power levels get checked due to MSC2209.
"""
@@ -395,7 +397,7 @@ class EventAuthTestCase(unittest.TestCase):
with self.assertRaises(AuthError):
event_auth.check_state_dependent_auth_rules(pl_event, auth_events)
- def test_join_rules_public(self):
+ def test_join_rules_public(self) -> None:
"""
Test joining a public room.
"""
@@ -460,7 +462,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events.values(),
)
- def test_join_rules_invite(self):
+ def test_join_rules_invite(self) -> None:
"""
Test joining an invite only room.
"""
@@ -835,7 +837,7 @@ def _power_levels_event(
)
-def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase:
+def _alias_event(room_version: RoomVersion, sender: str, **kwargs: Any) -> EventBase:
data = {
"room_id": TEST_ROOM_ID,
**_maybe_get_event_id_dict_for_room_version(room_version),
diff --git a/tests/test_mau.py b/tests/test_mau.py
index f14fcb7db9..4e7665a22b 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -14,12 +14,17 @@
"""Tests REST events for /rooms paths."""
-from typing import List
+from typing import List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.appservice import ApplicationService
from synapse.rest.client import register, sync
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.unittest import override_config
@@ -30,7 +35,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
servlets = [register.register_servlets, sync.register_servlets]
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = default_config("test")
config.update(
@@ -53,10 +58,12 @@ class TestMauLimit(unittest.HomeserverTestCase):
return config
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
- def test_simple_deny_mau(self):
+ def test_simple_deny_mau(self) -> None:
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -75,7 +82,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- def test_as_ignores_mau(self):
+ def test_as_ignores_mau(self) -> None:
"""Test that application services can still create users when the MAU
limit has been reached. This only works when application service
user ip tracking is disabled.
@@ -113,7 +120,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.create_user("as_kermit4", token=as_token, appservice=True)
- def test_allowed_after_a_month_mau(self):
+ def test_allowed_after_a_month_mau(self) -> None:
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -132,7 +139,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.do_sync_for_user(token3)
@override_config({"mau_trial_days": 1})
- def test_trial_delay(self):
+ def test_trial_delay(self) -> None:
# We should be able to register more than the limit initially
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -165,7 +172,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
@override_config({"mau_trial_days": 1})
- def test_trial_users_cant_come_back(self):
+ def test_trial_users_cant_come_back(self) -> None:
self.hs.config.server.mau_trial_days = 1
# We should be able to register more than the limit initially
@@ -216,7 +223,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
# max_mau_value should not matter
{"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True}
)
- def test_tracked_but_not_limited(self):
+ def test_tracked_but_not_limited(self) -> None:
# Simply being able to create 2 users indicates that the
# limit was not reached.
token1 = self.create_user("kermit1")
@@ -236,10 +243,10 @@ class TestMauLimit(unittest.HomeserverTestCase):
"mau_appservice_trial_days": {"SomeASID": 1, "AnotherASID": 2},
}
)
- def test_as_trial_days(self):
+ def test_as_trial_days(self) -> None:
user_tokens: List[str] = []
- def advance_time_and_sync():
+ def advance_time_and_sync() -> None:
self.reactor.advance(24 * 60 * 61)
for token in user_tokens:
self.do_sync_for_user(token)
@@ -300,7 +307,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
},
)
- def create_user(self, localpart, token=None, appservice=False):
+ def create_user(
+ self, localpart: str, token: Optional[str] = None, appservice: bool = False
+ ) -> str:
request_data = {
"username": localpart,
"password": "monkey",
@@ -326,7 +335,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
return access_token
- def do_sync_for_user(self, token):
+ def do_sync_for_user(self, token: str) -> None:
channel = self.make_request("GET", "/sync", access_token=token)
if channel.code != 200:
diff --git a/tests/test_rust.py b/tests/test_rust.py
index 55d8b6b28c..67443b6280 100644
--- a/tests/test_rust.py
+++ b/tests/test_rust.py
@@ -6,6 +6,6 @@ from tests import unittest
class RustTestCase(unittest.TestCase):
"""Basic tests to ensure that we can call into Rust code."""
- def test_basic(self):
+ def test_basic(self) -> None:
result = sum_as_string(1, 2)
self.assertEqual("3", result)
diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py
index d04bcae0fa..5cd698147e 100644
--- a/tests/test_test_utils.py
+++ b/tests/test_test_utils.py
@@ -17,25 +17,25 @@ from tests.utils import MockClock
class MockClockTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.clock = MockClock()
- def test_advance_time(self):
+ def test_advance_time(self) -> None:
start_time = self.clock.time()
self.clock.advance_time(20)
self.assertEqual(20, self.clock.time() - start_time)
- def test_later(self):
+ def test_later(self) -> None:
invoked = [0, 0]
- def _cb0():
+ def _cb0() -> None:
invoked[0] = 1
self.clock.call_later(10, _cb0)
- def _cb1():
+ def _cb1() -> None:
invoked[1] = 1
self.clock.call_later(20, _cb1)
@@ -51,15 +51,15 @@ class MockClockTestCase(unittest.TestCase):
self.assertTrue(invoked[1])
- def test_cancel_later(self):
+ def test_cancel_later(self) -> None:
invoked = [0, 0]
- def _cb0():
+ def _cb0() -> None:
invoked[0] = 1
t0 = self.clock.call_later(10, _cb0)
- def _cb1():
+ def _cb1() -> None:
invoked[1] = 1
self.clock.call_later(20, _cb1)
diff --git a/tests/test_types.py b/tests/test_types.py
index 1111169384..c491cc9a96 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -43,34 +43,34 @@ class IsMineIDTests(unittest.HomeserverTestCase):
class UserIDTestCase(unittest.HomeserverTestCase):
- def test_parse(self):
+ def test_parse(self) -> None:
user = UserID.from_string("@1234abcd:test")
self.assertEqual("1234abcd", user.localpart)
self.assertEqual("test", user.domain)
self.assertEqual(True, self.hs.is_mine(user))
- def test_parse_rejects_empty_id(self):
+ def test_parse_rejects_empty_id(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("")
- def test_parse_rejects_missing_sigil(self):
+ def test_parse_rejects_missing_sigil(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("alice:example.com")
- def test_parse_rejects_missing_separator(self):
+ def test_parse_rejects_missing_separator(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("@alice.example.com")
- def test_validation_rejects_missing_domain(self):
+ def test_validation_rejects_missing_domain(self) -> None:
self.assertFalse(UserID.is_valid("@alice:"))
- def test_build(self):
+ def test_build(self) -> None:
user = UserID("5678efgh", "my.domain")
self.assertEqual(user.to_string(), "@5678efgh:my.domain")
- def test_compare(self):
+ def test_compare(self) -> None:
userA = UserID.from_string("@userA:my.domain")
userAagain = UserID.from_string("@userA:my.domain")
userB = UserID.from_string("@userB:my.domain")
@@ -80,43 +80,43 @@ class UserIDTestCase(unittest.HomeserverTestCase):
class RoomAliasTestCase(unittest.HomeserverTestCase):
- def test_parse(self):
+ def test_parse(self) -> None:
room = RoomAlias.from_string("#channel:test")
self.assertEqual("channel", room.localpart)
self.assertEqual("test", room.domain)
self.assertEqual(True, self.hs.is_mine(room))
- def test_build(self):
+ def test_build(self) -> None:
room = RoomAlias("channel", "my.domain")
self.assertEqual(room.to_string(), "#channel:my.domain")
- def test_validate(self):
+ def test_validate(self) -> None:
id_string = "#test:domain,test"
self.assertFalse(RoomAlias.is_valid(id_string))
class MapUsernameTestCase(unittest.TestCase):
- def testPassThrough(self):
+ def test_pass_througuh(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
- def testUpperCase(self):
+ def test_upper_case(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
self.assertEqual(
map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
"t_e_s_t__1234",
)
- def testSymbols(self):
+ def test_symbols(self) -> None:
self.assertEqual(
map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234"
)
- def testLeadingUnderscore(self):
+ def test_leading_underscore(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")
- def testNonAscii(self):
+ def test_non_ascii(self) -> None:
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
|