diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index a9893def74..741bb6464a 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -31,7 +31,11 @@ from synapse.util import Clock
from tests.handlers.test_sync import generate_sync_config
from tests.test_utils import simple_async_mock
-from tests.unittest import FederatingHomeserverTestCase, override_config
+from tests.unittest import (
+ FederatingHomeserverTestCase,
+ HomeserverTestCase,
+ override_config,
+)
@attr.s
@@ -470,7 +474,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
def send_presence_update(
- testcase: FederatingHomeserverTestCase,
+ testcase: HomeserverTestCase,
user_id: str,
access_token: str,
presence_state: str,
@@ -491,7 +495,7 @@ def send_presence_update(
def sync_presence(
- testcase: FederatingHomeserverTestCase,
+ testcase: HomeserverTestCase,
user_id: str,
since_token: Optional[StreamToken] = None,
) -> Tuple[List[UserPresenceState], StreamToken]:
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 57675fa407..5868eb2da7 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -575,26 +575,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_client = fed_handler.federation_client
room_id = "!room:example.com"
- membership_event = make_event_from_dict(
- {
- "room_id": room_id,
- "type": "m.room.member",
- "sender": "@alice:test",
- "state_key": "@alice:test",
- "content": {"membership": "join"},
- },
- RoomVersions.V10,
- )
-
- mock_make_membership_event = Mock(
- return_value=make_awaitable(
- (
- "example.com",
- membership_event,
- RoomVersions.V10,
- )
- )
- )
EVENT_CREATE = make_event_from_dict(
{
@@ -640,6 +620,26 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
},
room_version=RoomVersions.V10,
)
+ membership_event = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "sender": "@alice:test",
+ "state_key": "@alice:test",
+ "content": {"membership": "join"},
+ "prev_events": [EVENT_INVITATION_MEMBERSHIP.event_id],
+ },
+ RoomVersions.V10,
+ )
+ mock_make_membership_event = Mock(
+ return_value=make_awaitable(
+ (
+ "example.com",
+ membership_event,
+ RoomVersions.V10,
+ )
+ )
+ )
mock_send_join = Mock(
return_value=make_awaitable(
SendJoinResult(
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index adddbd002f..951caaa6b3 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -150,7 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver()
self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
- self.hs_patcher.start()
+ self.hs_patcher.start() # type: ignore[attr-defined]
self.handler = hs.get_oidc_handler()
self.provider = self.handler._providers["oidc"]
@@ -170,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return hs
def tearDown(self) -> None:
- self.hs_patcher.stop()
+ self.hs_patcher.stop() # type: ignore[attr-defined]
return super().tearDown()
def reset_mocks(self) -> None:
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 75fc5a17a4..e9be5fb504 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -949,10 +949,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(
self.hs.get_storage_controllers().persistence.persist_event(event, context)
)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 8f88c0117d..cc173ebda6 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict
from unittest.mock import Mock
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, EventTypes
from synapse.api.errors import NotFoundError
@@ -21,9 +23,12 @@ from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
from synapse.handlers.push_rules import InvalidRuleException
+from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login, notifications, presence, profile, room
-from synapse.types import create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, create_requester
+from synapse.util import Clock
from tests.events.test_presence_router import send_presence_update, sync_presence
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -32,7 +37,19 @@ from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
-class ModuleApiTestCase(HomeserverTestCase):
+class BaseModuleApiTestCase(HomeserverTestCase):
+ """Common properties of the two test case classes."""
+
+ module_api: ModuleApi
+
+ # These are all written by _test_sending_local_online_presence_to_local_user.
+ presence_receiver_id: str
+ presence_receiver_tok: str
+ presence_sender_id: str
+ presence_sender_tok: str
+
+
+class ModuleApiTestCase(BaseModuleApiTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -42,14 +59,14 @@ class ModuleApiTestCase(HomeserverTestCase):
notifications.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
- self.store = homeserver.get_datastores().main
- self.module_api = homeserver.get_module_api()
- self.event_creation_handler = homeserver.get_event_creation_handler()
- self.sync_handler = homeserver.get_sync_handler()
- self.auth_handler = homeserver.get_auth_handler()
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.module_api = hs.get_module_api()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.sync_handler = hs.get_sync_handler()
+ self.auth_handler = hs.get_auth_handler()
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
fed_transport_client = Mock(spec=["send_transaction"])
fed_transport_client.send_transaction = simple_async_mock({})
@@ -58,7 +75,7 @@ class ModuleApiTestCase(HomeserverTestCase):
federation_transport_client=fed_transport_client,
)
- def test_can_register_user(self):
+ def test_can_register_user(self) -> None:
"""Tests that an external module can register a user"""
# Register a new user
user_id, access_token = self.get_success(
@@ -88,16 +105,17 @@ class ModuleApiTestCase(HomeserverTestCase):
displayname = self.get_success(self.store.get_profile_displayname("bob"))
self.assertEqual(displayname, "Bobberino")
- def test_can_register_admin_user(self):
+ def test_can_register_admin_user(self) -> None:
user_id = self.register_user(
"bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
)
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True)
- def test_can_set_admin(self):
+ def test_can_set_admin(self) -> None:
user_id = self.register_user(
"alice_wants_admin",
"1234",
@@ -107,16 +125,17 @@ class ModuleApiTestCase(HomeserverTestCase):
self.get_success(self.module_api.set_user_admin(user_id, True))
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True)
- def test_can_set_displayname(self):
+ def test_can_set_displayname(self) -> None:
localpart = "alice_wants_a_new_displayname"
user_id = self.register_user(
localpart, "1234", displayname="Alice", admin=False
)
found_userinfo = self.get_success(self.module_api.get_userinfo_by_id(user_id))
-
+ assert found_userinfo is not None
self.get_success(
self.module_api.set_displayname(
found_userinfo.user_id, "Bob", deactivation=False
@@ -128,17 +147,18 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(found_profile.display_name, "Bob")
- def test_get_userinfo_by_id(self):
+ def test_get_userinfo_by_id(self) -> None:
user_id = self.register_user("alice", "1234")
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, False)
- def test_get_userinfo_by_id__no_user_found(self):
+ def test_get_userinfo_by_id__no_user_found(self) -> None:
found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
self.assertIsNone(found_user)
- def test_get_user_ip_and_agents(self):
+ def test_get_user_ip_and_agents(self) -> None:
user_id = self.register_user("test_get_user_ip_and_agents_user", "1234")
# Initially, we should have no ip/agent for our user.
@@ -185,7 +205,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# we should only find the second ip, agent.
info = self.get_success(
self.module_api.get_user_ip_and_agents(
- user_id, (last_seen_1 + last_seen_2) / 2
+ user_id, (last_seen_1 + last_seen_2) // 2
)
)
self.assertEqual(len(info), 1)
@@ -200,7 +220,7 @@ class ModuleApiTestCase(HomeserverTestCase):
)
self.assertEqual(info, [])
- def test_get_user_ip_and_agents__no_user_found(self):
+ def test_get_user_ip_and_agents__no_user_found(self) -> None:
info = self.get_success(
self.module_api.get_user_ip_and_agents(
"@test_get_user_ip_and_agents_user_nonexistent:example.com"
@@ -208,10 +228,10 @@ class ModuleApiTestCase(HomeserverTestCase):
)
self.assertEqual(info, [])
- def test_sending_events_into_room(self):
+ def test_sending_events_into_room(self) -> None:
"""Tests that a module can send events into a room"""
# Mock out create_and_send_nonmember_event to check whether events are being sent
- self.event_creation_handler.create_and_send_nonmember_event = Mock(
+ self.event_creation_handler.create_and_send_nonmember_event = Mock( # type: ignore[assignment]
spec=[],
side_effect=self.event_creation_handler.create_and_send_nonmember_event,
)
@@ -222,7 +242,7 @@ class ModuleApiTestCase(HomeserverTestCase):
room_id = self.helper.create_room_as(user_id, tok=tok)
# Create and send a non-state event
- content = {"body": "I am a puppet", "msgtype": "m.text"}
+ content: JsonDict = {"body": "I am a puppet", "msgtype": "m.text"}
event_dict = {
"room_id": room_id,
"type": "m.room.message",
@@ -265,7 +285,7 @@ class ModuleApiTestCase(HomeserverTestCase):
"sender": user_id,
"state_key": "",
}
- event: EventBase = self.get_success(
+ event = self.get_success(
self.module_api.create_and_send_event_into_room(event_dict)
)
self.assertEqual(event.sender, user_id)
@@ -303,7 +323,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.create_and_send_event_into_room(event_dict), Exception
)
- def test_public_rooms(self):
+ def test_public_rooms(self) -> None:
"""Tests that a room can be added and removed from the public rooms list,
as well as have its public rooms directory state queried.
"""
@@ -350,13 +370,13 @@ class ModuleApiTestCase(HomeserverTestCase):
)
self.assertFalse(is_in_public_rooms)
- def test_send_local_online_presence_to(self):
+ def test_send_local_online_presence_to(self) -> None:
# Test sending local online presence to users from the main process
_test_sending_local_online_presence_to_local_user(self, test_with_workers=False)
# Enable federation sending on the main process.
@override_config({"federation_sender_instances": None})
- def test_send_local_online_presence_to_federation(self):
+ def test_send_local_online_presence_to_federation(self) -> None:
"""Tests that send_local_presence_to_users sends local online presence to remote users."""
# Create a user who will send presence updates
self.presence_sender_id = self.register_user("presence_sender1", "monkey")
@@ -431,7 +451,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertTrue(found_update)
- def test_update_membership(self):
+ def test_update_membership(self) -> None:
"""Tests that the module API can update the membership of a user in a room."""
peter = self.register_user("peter", "hackme")
lesley = self.register_user("lesley", "hackme")
@@ -554,7 +574,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(res["displayname"], "simone")
self.assertIsNone(res["avatar_url"])
- def test_update_room_membership_remote_join(self):
+ def test_update_room_membership_remote_join(self) -> None:
"""Test that the module API can join a remote room."""
# Necessary to fake a remote join.
fake_stream_id = 1
@@ -582,7 +602,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that a remote join was attempted.
self.assertEqual(mocked_remote_join.call_count, 1)
- def test_get_room_state(self):
+ def test_get_room_state(self) -> None:
"""Tests that a module can retrieve the state of a room through the module API."""
user_id = self.register_user("peter", "hackme")
tok = self.login("peter", "hackme")
@@ -677,7 +697,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.check_push_rule_actions(["foo"])
with self.assertRaises(InvalidRuleException):
- self.module_api.check_push_rule_actions({"foo": "bar"})
+ self.module_api.check_push_rule_actions([{"foo": "bar"}])
self.module_api.check_push_rule_actions(["notify"])
@@ -756,7 +776,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertIsNone(room_alias)
-class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
+class ModuleApiWorkerTestCase(BaseModuleApiTestCase, BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""
servlets = [
@@ -766,7 +786,7 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
presence.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
conf = super().default_config()
conf["stream_writers"] = {"presence": ["presence_writer"]}
conf["instance_map"] = {
@@ -774,18 +794,18 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
}
return conf
- def prepare(self, reactor, clock, homeserver):
- self.module_api = homeserver.get_module_api()
- self.sync_handler = homeserver.get_sync_handler()
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.module_api = hs.get_module_api()
+ self.sync_handler = hs.get_sync_handler()
- def test_send_local_online_presence_to_workers(self):
+ def test_send_local_online_presence_to_workers(self) -> None:
# Test sending local online presence to users from a worker process
_test_sending_local_online_presence_to_local_user(self, test_with_workers=True)
def _test_sending_local_online_presence_to_local_user(
- test_case: HomeserverTestCase, test_with_workers: bool = False
-):
+ test_case: BaseModuleApiTestCase, test_with_workers: bool = False
+) -> None:
"""Tests that send_local_presence_to_users sends local online presence to local users.
This simultaneously tests two different usecases:
@@ -852,6 +872,7 @@ def _test_sending_local_online_presence_to_local_user(
# Replicate the current sync presence token from the main process to the worker process.
# We need to do this so that the worker process knows the current presence stream ID to
# insert into the database when we call ModuleApi.send_local_online_presence_to.
+ assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
test_case.replicate()
# Syncing again should result in no presence updates
@@ -868,6 +889,7 @@ def _test_sending_local_online_presence_to_local_user(
# Determine on which process (main or worker) to call ModuleApi.send_local_online_presence_to on
if test_with_workers:
+ assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
module_api_to_use = worker_hs.get_module_api()
else:
module_api_to_use = test_case.module_api
@@ -875,12 +897,11 @@ def _test_sending_local_online_presence_to_local_user(
# Trigger sending local online presence. We expect this information
# to be saved to the database where all processes can access it.
# Note that we're syncing via the master.
- d = module_api_to_use.send_local_online_presence_to(
- [
- test_case.presence_receiver_id,
- ]
+ d = defer.ensureDeferred(
+ module_api_to_use.send_local_online_presence_to(
+ [test_case.presence_receiver_id],
+ )
)
- d = defer.ensureDeferred(d)
if test_with_workers:
# In order for the required presence_set_state replication request to occur between the
@@ -897,7 +918,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)
- presence_update: UserPresenceState = presence_updates[0]
+ presence_update = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")
@@ -908,7 +929,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)
- presence_update: UserPresenceState = presence_updates[0]
+ presence_update = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")
@@ -936,12 +957,13 @@ def _test_sending_local_online_presence_to_local_user(
test_case.assertEqual(len(presence_updates), 1)
# Now trigger sending local online presence.
- d = module_api_to_use.send_local_online_presence_to(
- [
- test_case.presence_receiver_id,
- ]
+ d = defer.ensureDeferred(
+ module_api_to_use.send_local_online_presence_to(
+ [
+ test_case.presence_receiver_id,
+ ]
+ )
)
- d = defer.ensureDeferred(d)
if test_with_workers:
# In order for the required presence_set_state replication request to occur between the
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index da33423871..6603447341 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -48,8 +48,16 @@ class FlattenDictTestCase(unittest.TestCase):
input = {"foo": {"bar": "abc"}}
self.assertEqual({"foo.bar": "abc"}, _flatten_dict(input))
+ # If a field has a dot in it, escape it.
+ input = {"m.foo": {"b\\ar": "abc"}}
+ self.assertEqual({"m.foo.b\\ar": "abc"}, _flatten_dict(input))
+ self.assertEqual(
+ {"m\\.foo.b\\\\ar": "abc"},
+ _flatten_dict(input, msc3783_escape_event_match_key=True),
+ )
+
def test_non_string(self) -> None:
- """Non-string items are dropped."""
+ """Booleans, ints, and nulls should be kept while other items are dropped."""
input: Dict[str, Any] = {
"woo": "woo",
"foo": True,
@@ -58,7 +66,9 @@ class FlattenDictTestCase(unittest.TestCase):
"fuzz": [],
"boo": {},
}
- self.assertEqual({"woo": "woo"}, _flatten_dict(input))
+ self.assertEqual(
+ {"woo": "woo", "foo": True, "bar": 1, "baz": None}, _flatten_dict(input)
+ )
def test_event(self) -> None:
"""Events can also be flattened."""
@@ -78,9 +88,9 @@ class FlattenDictTestCase(unittest.TestCase):
)
expected = {
"content.msgtype": "m.text",
- "content.body": "hello world!",
+ "content.body": "Hello world!",
"content.format": "org.matrix.custom.html",
- "content.formatted_body": "<h1>hello world!</h1>",
+ "content.formatted_body": "<h1>Hello world!</h1>",
"room_id": "!test:test",
"sender": "@alice:test",
"type": "m.room.message",
@@ -158,6 +168,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
related_event_match_enabled=True,
room_version_feature_flags=event.room_version.msc3931_push_features,
msc3931_enabled=True,
+ msc3758_exact_event_match=True,
)
def test_display_name(self) -> None:
@@ -402,6 +413,142 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"pattern should not match before a newline",
)
+ def test_exact_event_match_string(self) -> None:
+ """Check that exact_event_match conditions work as expected for strings."""
+
+ # Test against a string value.
+ condition = {
+ "kind": "com.beeper.msc3758.exact_event_match",
+ "key": "content.value",
+ "value": "foobaz",
+ }
+ self._assert_matches(
+ condition,
+ {"value": "foobaz"},
+ "exact value should match",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": "FoobaZ"},
+ "values should match and be case-sensitive",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": "test foobaz test"},
+ "values must exactly match",
+ )
+ value: Any
+ for value in (True, False, 1, 1.1, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ # it should work on frozendicts too
+ self._assert_matches(
+ condition,
+ frozendict.frozendict({"value": "foobaz"}),
+ "values should match on frozendicts",
+ )
+
+ def test_exact_event_match_boolean(self) -> None:
+ """Check that exact_event_match conditions work as expected for booleans."""
+
+ # Test against a True boolean value.
+ condition = {
+ "kind": "com.beeper.msc3758.exact_event_match",
+ "key": "content.value",
+ "value": True,
+ }
+ self._assert_matches(
+ condition,
+ {"value": True},
+ "exact value should match",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": False},
+ "incorrect values should not match",
+ )
+ for value in ("foobaz", 1, 1.1, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ # Test against a False boolean value.
+ condition = {
+ "kind": "com.beeper.msc3758.exact_event_match",
+ "key": "content.value",
+ "value": False,
+ }
+ self._assert_matches(
+ condition,
+ {"value": False},
+ "exact value should match",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": True},
+ "incorrect values should not match",
+ )
+ # Choose false-y values to ensure there's no type coercion.
+ for value in ("", 0, 1.1, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ def test_exact_event_match_null(self) -> None:
+ """Check that exact_event_match conditions work as expected for null."""
+
+ condition = {
+ "kind": "com.beeper.msc3758.exact_event_match",
+ "key": "content.value",
+ "value": None,
+ }
+ self._assert_matches(
+ condition,
+ {"value": None},
+ "exact value should match",
+ )
+ for value in ("foobaz", True, False, 1, 1.1, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ def test_exact_event_match_integer(self) -> None:
+ """Check that exact_event_match conditions work as expected for integers."""
+
+ condition = {
+ "kind": "com.beeper.msc3758.exact_event_match",
+ "key": "content.value",
+ "value": 1,
+ }
+ self._assert_matches(
+ condition,
+ {"value": 1},
+ "exact value should match",
+ )
+ value: Any
+ for value in (1.1, -1, 0):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect values should not match",
+ )
+ for value in ("1", True, False, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
def test_no_body(self) -> None:
"""Not having a body shouldn't break the evaluator."""
evaluator = self._get_evaluator({})
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index aadb31ca83..db77a45ae3 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -213,7 +213,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.admin_user_tok = self.login("admin", "pass")
self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
- self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
+ self.url = "/_synapse/admin/v1/media/delete"
+ self.legacy_url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
# Move clock up to somewhat realistic time
self.reactor.advance(1000000000)
@@ -332,11 +333,13 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel.json_body["error"],
)
- def test_delete_media_never_accessed(self) -> None:
+ @parameterized.expand([(True,), (False,)])
+ def test_delete_media_never_accessed(self, use_legacy_url: bool) -> None:
"""
Tests that media deleted if it is older than `before_ts` and never accessed
`last_access_ts` is `NULL` and `created_ts` < `before_ts`
"""
+ url = self.legacy_url if use_legacy_url else self.url
# upload and do not access
server_and_media_id = self._create_media()
@@ -351,7 +354,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
now_ms = self.clock.time_msec()
channel = self.make_request(
"POST",
- self.url + "?before_ts=" + str(now_ms),
+ url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index a2f347f666..f71ff46d87 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List
+from typing import List, Sequence
from twisted.test.proto_helpers import MemoryReactor
@@ -558,7 +558,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
def _check_invite_and_join_status(
self, user_id: str, expected_invites: int, expected_memberships: int
- ) -> List[RoomsForUser]:
+ ) -> Sequence[RoomsForUser]:
"""Check invite and room membership status of a user.
Args
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 5c1ced355f..b50406e129 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2934,10 +2934,12 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(storage_controllers.persistence.persist_event(event, context))
# Now get rooms
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index d18fc13c21..17a3b06a8e 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -16,7 +16,7 @@ import shutil
import tempfile
from binascii import unhexlify
from io import BytesIO
-from typing import Any, BinaryIO, Dict, List, Optional, Union
+from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
from unittest.mock import Mock
from urllib import parse
@@ -32,6 +32,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.events import EventBase
from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
from synapse.module_api import ModuleApi
from synapse.rest import admin
@@ -41,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from synapse.server import HomeServer
-from synapse.types import RoomAlias
+from synapse.types import JsonDict, RoomAlias
from synapse.util import Clock
from tests import unittest
@@ -201,36 +202,46 @@ class _TestImage:
],
)
class MediaRepoTests(unittest.HomeserverTestCase):
-
+ test_image: ClassVar[_TestImage]
hijack_auth = True
user_id = "@test:user"
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.fetches = []
+ self.fetches: List[
+ Tuple[
+ "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]",
+ str,
+ str,
+ Optional[QueryParams],
+ ]
+ ] = []
def get_file(
destination: str,
path: str,
output_stream: BinaryIO,
- args: Optional[Dict[str, Union[str, List[str]]]] = None,
+ args: Optional[QueryParams] = None,
+ retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
- ) -> Deferred:
- """
- Returns tuple[int,dict,str,int] of file length, response headers,
- absolute URI, and response code.
- """
+ ignore_backoff: bool = False,
+ ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
+ """A mock for MatrixFederationHttpClient.get_file."""
- def write_to(r):
+ def write_to(
+ r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
+ ) -> Tuple[int, Dict[bytes, List[bytes]]]:
data, response = r
output_stream.write(data)
return response
- d = Deferred()
- d.addCallback(write_to)
+ d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
self.fetches.append((d, destination, path, args))
- return make_deferred_yieldable(d)
+ # Note that this callback changes the value held by d.
+ d_after_callback = d.addCallback(write_to)
+ return make_deferred_yieldable(d_after_callback)
+ # Mock out the homeserver's MatrixFederationHttpClient
client = Mock()
client.get_file = get_file
@@ -461,6 +472,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
# Synapse should regenerate missing thumbnails.
origin, media_id = self.media_id.split("/")
info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
+ assert info is not None
file_id = info["filesystem_id"]
thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
@@ -581,7 +593,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"thumbnail_method": method,
"thumbnail_type": self.test_image.content_type,
"thumbnail_length": 256,
- "filesystem_id": f"thumbnail1{self.test_image.extension}",
+ "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}",
},
{
"thumbnail_width": 32,
@@ -589,10 +601,10 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"thumbnail_method": method,
"thumbnail_type": self.test_image.content_type,
"thumbnail_length": 256,
- "filesystem_id": f"thumbnail2{self.test_image.extension}",
+ "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}",
},
],
- file_id=f"image{self.test_image.extension}",
+ file_id=f"image{self.test_image.extension.decode()}",
url_cache=None,
server_name=None,
)
@@ -637,6 +649,7 @@ class TestSpamCheckerLegacy:
self.config = config
self.api = api
+ @staticmethod
def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
return config
@@ -748,7 +761,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
async def check_media_file_for_spam(
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
- ) -> Union[Codes, Literal["NOT_SPAM"]]:
+ ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]:
buf = BytesIO()
await file_wrapper.write_chunks_to(buf.write)
diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
index 22f99c6ab1..3285f2433c 100644
--- a/tests/scripts/test_new_matrix_user.py
+++ b/tests/scripts/test_new_matrix_user.py
@@ -12,29 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List
+from typing import List, Optional
from unittest.mock import Mock, patch
from synapse._scripts.register_new_matrix_user import request_registration
+from synapse.types import JsonDict
from tests.unittest import TestCase
class RegisterTestCase(TestCase):
- def test_success(self):
+ def test_success(self) -> None:
"""
The script will fetch a nonce, and then generate a MAC with it, and then
post that MAC.
"""
- def get(url, verify=None):
+ def get(url: str, verify: Optional[bool] = None) -> Mock:
r = Mock()
r.status_code = 200
r.json = lambda: {"nonce": "a"}
return r
- def post(url, json=None, verify=None):
+ def post(
+ url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
+ ) -> Mock:
# Make sure we are sent the correct info
+ assert json is not None
self.assertEqual(json["username"], "user")
self.assertEqual(json["password"], "pass")
self.assertEqual(json["nonce"], "a")
@@ -70,12 +74,12 @@ class RegisterTestCase(TestCase):
# sys.exit shouldn't have been called.
self.assertEqual(err_code, [])
- def test_failure_nonce(self):
+ def test_failure_nonce(self) -> None:
"""
If the script fails to fetch a nonce, it throws an error and quits.
"""
- def get(url, verify=None):
+ def get(url: str, verify: Optional[bool] = None) -> Mock:
r = Mock()
r.status_code = 404
r.reason = "Not Found"
@@ -107,20 +111,23 @@ class RegisterTestCase(TestCase):
self.assertIn("ERROR! Received 404 Not Found", out)
self.assertNotIn("Success!", out)
- def test_failure_post(self):
+ def test_failure_post(self) -> None:
"""
The script will fetch a nonce, and then if the final POST fails, will
report an error and quit.
"""
- def get(url, verify=None):
+ def get(url: str, verify: Optional[bool] = None) -> Mock:
r = Mock()
r.status_code = 200
r.json = lambda: {"nonce": "a"}
return r
- def post(url, json=None, verify=None):
+ def post(
+ url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
+ ) -> Mock:
# Make sure we are sent the correct info
+ assert json is not None
self.assertEqual(json["username"], "user")
self.assertEqual(json["password"], "pass")
self.assertEqual(json["nonce"], "a")
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index 58b399a043..6540ed53f1 100644
--- a/tests/server_notices/test_consent.py
+++ b/tests/server_notices/test_consent.py
@@ -14,8 +14,12 @@
import os
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
@@ -29,7 +33,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
tmpdir = self.mktemp()
os.mkdir(tmpdir)
@@ -53,15 +57,13 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
"room_name": "Server Notices",
}
- hs = self.setup_test_homeserver(config=config)
-
- return hs
+ return self.setup_test_homeserver(config=config)
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("bob", "abc123")
self.access_token = self.login("bob", "abc123")
- def test_get_sync_message(self):
+ def test_get_sync_message(self) -> None:
"""
When user consent server notices are enabled, a sync will cause a notice
to fire (in a room which the user is invited to). The notice contains
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index dadc6efcbf..5b76383d76 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -24,6 +24,7 @@ from synapse.server import HomeServer
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
)
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
@@ -33,7 +34,7 @@ from tests.utils import default_config
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = default_config("test")
config.update(
@@ -86,18 +87,18 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment]
@override_config({"hs_disabled": True})
- def test_maybe_send_server_notice_disabled_hs(self):
+ def test_maybe_send_server_notice_disabled_hs(self) -> None:
"""If the HS is disabled, we should not send notices"""
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
@override_config({"limit_usage_by_mau": False})
- def test_maybe_send_server_notice_to_user_flag_off(self):
+ def test_maybe_send_server_notice_to_user_flag_off(self) -> None:
"""If mau limiting is disabled, we should not send notices"""
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
- def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
+ def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None:
"""Test when user has blocked notice, but should have it removed"""
self._rlsn._auth_blocking.check_auth_blocking = Mock(
@@ -114,7 +115,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once()
self._send_notice.assert_called_once()
- def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
+ def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None:
"""
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
@@ -134,7 +135,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
- def test_maybe_send_server_notice_to_user_add_blocked_notice(self):
+ def test_maybe_send_server_notice_to_user_add_blocked_notice(self) -> None:
"""
Test when user does not have blocked notice, but should have one
"""
@@ -147,7 +148,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
# Would be better to check contents, but 2 calls == set blocking event
self.assertEqual(self._send_notice.call_count, 2)
- def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self):
+ def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self) -> None:
"""
Test when user does not have blocked notice, nor should they (NOOP)
"""
@@ -159,7 +160,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
- def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self):
+ def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self) -> None:
"""
Test when user is not part of the MAU cohort - this should not ever
happen - but ...
@@ -175,7 +176,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
@override_config({"mau_limit_alerting": False})
- def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self):
+ def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(
+ self,
+ ) -> None:
"""
Test that when server is over MAU limit and alerting is suppressed, then
an alert message is not sent into the room
@@ -191,7 +194,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self.assertEqual(self._send_notice.call_count, 0)
@override_config({"mau_limit_alerting": False})
- def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
+ def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self) -> None:
"""
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
@@ -207,7 +210,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self.assertEqual(self._send_notice.call_count, 2)
@override_config({"mau_limit_alerting": False})
- def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self):
+ def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(
+ self,
+ ) -> None:
"""
When the room is already in a blocked state, test that when alerting
is suppressed that the room is returned to an unblocked state.
@@ -242,7 +247,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
sync.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> JsonDict:
c = super().default_config()
c["server_notices"] = {
"system_mxid_localpart": "server",
@@ -270,7 +275,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test"
- def test_server_notice_only_sent_once(self):
+ def test_server_notice_only_sent_once(self) -> None:
self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
self.store.user_last_seen_monthly_active = Mock(
@@ -306,7 +311,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.assertEqual(count, 1)
- def test_no_invite_without_notice(self):
+ def test_no_invite_without_notice(self) -> None:
"""Tests that a user doesn't get invited to a server notices room without a
server notice being sent.
@@ -328,7 +333,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
m.assert_called_once_with(user_id)
- def test_invite_with_notice(self):
+ def test_invite_with_notice(self) -> None:
"""Tests that, if the MAU limit is hit, the server notices user invites each user
to a room in which it has sent a notice.
"""
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index df4740f9d9..0100f7da14 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -74,10 +74,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(self._persistence.persist_event(event, context))
return event
@@ -96,10 +98,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(self._persistence.persist_event(event, context))
return event
@@ -119,10 +123,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(self._persistence.persist_event(event, context))
return event
@@ -259,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def internal_metadata(self) -> _EventInternalMetadata:
return self._base_builder.internal_metadata
- event_1, context_1 = self.get_success(
+ event_1, unpersisted_context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
@@ -280,9 +286,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
+ context_1 = self.get_success(unpersisted_context_1.persist(event_1))
+
self.get_success(self._persistence.persist_event(event_1, context_1))
- event_2, context_2 = self.get_success(
+ event_2, unpersisted_context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
@@ -302,6 +310,8 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
)
+
+ context_2 = self.get_success(unpersisted_context_2.persist(event_2))
self.get_success(self._persistence.persist_event(event_2, context_2))
# fetch one of the redactions
@@ -421,10 +431,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- redaction_event, context = self.get_success(
+ redaction_event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(redaction_event))
+
self.get_success(self._persistence.persist_event(redaction_event, context))
# Now lets jump to the future where we have censored the redaction event
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index bad7f0bc60..f730b888f7 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -67,10 +67,12 @@ class StateStoreTestCase(HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
assert self.storage.persistence is not None
self.get_success(self.storage.persistence.persist_event(event, context))
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 31546ea52b..a248f1d277 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -21,10 +21,10 @@ from . import unittest
class DistributorTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.dist = Distributor()
- def test_signal_dispatch(self):
+ def test_signal_dispatch(self) -> None:
self.dist.declare("alert")
observer = Mock()
@@ -33,7 +33,7 @@ class DistributorTestCase(unittest.TestCase):
self.dist.fire("alert", 1, 2, 3)
observer.assert_called_with(1, 2, 3)
- def test_signal_catch(self):
+ def test_signal_catch(self) -> None:
self.dist.declare("alarm")
observers = [Mock() for i in (1, 2)]
@@ -51,7 +51,7 @@ class DistributorTestCase(unittest.TestCase):
self.assertEqual(mock_logger.warning.call_count, 1)
self.assertIsInstance(mock_logger.warning.call_args[0][0], str)
- def test_signal_prereg(self):
+ def test_signal_prereg(self) -> None:
observer = Mock()
self.dist.observe("flare", observer)
@@ -60,8 +60,8 @@ class DistributorTestCase(unittest.TestCase):
observer.assert_called_with(4, 5)
- def test_signal_undeclared(self):
- def code():
+ def test_signal_undeclared(self) -> None:
+ def code() -> None:
self.dist.fire("notification")
self.assertRaises(KeyError, code)
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 0a7937f1cc..2860564afc 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -31,13 +31,13 @@ from tests.test_utils import get_awaitable_result
class _StubEventSourceStore:
"""A stub implementation of the EventSourceStore"""
- def __init__(self):
+ def __init__(self) -> None:
self._store: Dict[str, EventBase] = {}
- def add_event(self, event: EventBase):
+ def add_event(self, event: EventBase) -> None:
self._store[event.event_id] = event
- def add_events(self, events: Iterable[EventBase]):
+ def add_events(self, events: Iterable[EventBase]) -> None:
for event in events:
self._store[event.event_id] = event
@@ -59,7 +59,7 @@ class _StubEventSourceStore:
class EventAuthTestCase(unittest.TestCase):
- def test_rejected_auth_events(self):
+ def test_rejected_auth_events(self) -> None:
"""
Events that refer to rejected events in their auth events are rejected
"""
@@ -109,7 +109,7 @@ class EventAuthTestCase(unittest.TestCase):
)
)
- def test_create_event_with_prev_events(self):
+ def test_create_event_with_prev_events(self) -> None:
"""A create event with prev_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -150,7 +150,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event)
)
- def test_duplicate_auth_events(self):
+ def test_duplicate_auth_events(self) -> None:
"""Events with duplicate auth_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -196,7 +196,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event2)
)
- def test_unexpected_auth_events(self):
+ def test_unexpected_auth_events(self) -> None:
"""Events with excess auth_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -236,7 +236,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event)
)
- def test_random_users_cannot_send_state_before_first_pl(self):
+ def test_random_users_cannot_send_state_before_first_pl(self) -> None:
"""
Check that, before the first PL lands, the creator is the only user
that can send a state event.
@@ -263,7 +263,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events,
)
- def test_state_default_level(self):
+ def test_state_default_level(self) -> None:
"""
Check that users above the state_default level can send state and
those below cannot
@@ -298,7 +298,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events,
)
- def test_alias_event(self):
+ def test_alias_event(self) -> None:
"""Alias events have special behavior up through room version 6."""
creator = "@creator:example.com"
other = "@other:example.com"
@@ -333,7 +333,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events,
)
- def test_msc2432_alias_event(self):
+ def test_msc2432_alias_event(self) -> None:
"""After MSC2432, alias events have no special behavior."""
creator = "@creator:example.com"
other = "@other:example.com"
@@ -366,7 +366,9 @@ class EventAuthTestCase(unittest.TestCase):
)
@parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)])
- def test_notifications(self, room_version: RoomVersion, allow_modification: bool):
+ def test_notifications(
+ self, room_version: RoomVersion, allow_modification: bool
+ ) -> None:
"""
Notifications power levels get checked due to MSC2209.
"""
@@ -395,7 +397,7 @@ class EventAuthTestCase(unittest.TestCase):
with self.assertRaises(AuthError):
event_auth.check_state_dependent_auth_rules(pl_event, auth_events)
- def test_join_rules_public(self):
+ def test_join_rules_public(self) -> None:
"""
Test joining a public room.
"""
@@ -460,7 +462,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events.values(),
)
- def test_join_rules_invite(self):
+ def test_join_rules_invite(self) -> None:
"""
Test joining an invite only room.
"""
@@ -835,7 +837,7 @@ def _power_levels_event(
)
-def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase:
+def _alias_event(room_version: RoomVersion, sender: str, **kwargs: Any) -> EventBase:
data = {
"room_id": TEST_ROOM_ID,
**_maybe_get_event_id_dict_for_room_version(room_version),
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 80e5c590d8..ddb43c8c98 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -12,53 +12,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional, Union
from unittest.mock import Mock
from twisted.internet.defer import succeed
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import FederationError
from synapse.api.room_versions import RoomVersions
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import event_from_pdu_json
+from synapse.http.types import QueryParams
from synapse.logging.context import LoggingContext
-from synapse.types import UserID, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination
from tests import unittest
-from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
from tests.test_utils import make_awaitable
class MessageAcceptTests(unittest.HomeserverTestCase):
- def setUp(self):
-
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock()
- self.reactor = ThreadedMemoryReactorClock()
- self.hs_clock = Clock(self.reactor)
- self.homeserver = setup_test_homeserver(
- self.addCleanup,
- federation_http_client=self.http_client,
- clock=self.hs_clock,
- reactor=self.reactor,
- )
+ return self.setup_test_homeserver(federation_http_client=self.http_client)
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
user_id = UserID("us", "test")
our_user = create_requester(user_id)
- room_creator = self.homeserver.get_room_creation_handler()
+ room_creator = self.hs.get_room_creation_handler()
self.room_id = self.get_success(
room_creator.create_room(
our_user, room_creator._presets_dict["public_chat"], ratelimit=False
)
)[0]["room_id"]
- self.store = self.homeserver.get_datastores().main
+ self.store = self.hs.get_datastores().main
# Figure out what the most recent event is
most_recent = self.get_success(
- self.homeserver.get_datastores().main.get_latest_event_ids_in_room(
- self.room_id
- )
+ self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)[0]
join_event = make_event_from_dict(
@@ -78,14 +73,16 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- self.handler = self.homeserver.get_federation_handler()
- federation_event_handler = self.homeserver.get_federation_event_handler()
+ self.handler = self.hs.get_federation_handler()
+ federation_event_handler = self.hs.get_federation_event_handler()
- async def _check_event_auth(origin, event, context):
+ async def _check_event_auth(
+ origin: Optional[str], event: EventBase, context: EventContext
+ ) -> None:
pass
federation_event_handler._check_event_auth = _check_event_auth
- self.client = self.homeserver.get_federation_client()
+ self.client = self.hs.get_federation_client()
self.client._check_sigs_and_hash_for_pulled_events_and_fetch = (
lambda dest, pdus, **k: succeed(pdus)
)
@@ -104,16 +101,25 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
"$join:test.serv",
)
- def test_cant_hide_direct_ancestors(self):
+ def test_cant_hide_direct_ancestors(self) -> None:
"""
If you send a message, you must be able to provide the direct
prev_events that said event references.
"""
- async def post_json(destination, path, data, headers=None, timeout=0):
+ async def post_json(
+ destination: str,
+ path: str,
+ data: Optional[JsonDict] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ args: Optional[QueryParams] = None,
+ ) -> Union[JsonDict, list]:
# If it asks us for new missing events, give them NOTHING
if path.startswith("/_matrix/federation/v1/get_missing_events/"):
return {"events": []}
+ return {}
self.http_client.post_json = post_json
@@ -138,7 +144,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- federation_event_handler = self.homeserver.get_federation_event_handler()
+ federation_event_handler = self.hs.get_federation_event_handler()
with LoggingContext("test-context"):
failure = self.get_failure(
federation_event_handler.on_receive_pdu("test.serv", lying_event),
@@ -158,7 +164,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
self.assertEqual(extrem[0], "$join:test.serv")
- def test_retry_device_list_resync(self):
+ def test_retry_device_list_resync(self) -> None:
"""Tests that device lists are marked as stale if they couldn't be synced, and
that stale device lists are retried periodically.
"""
@@ -171,24 +177,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# When this function is called, increment the number of resync attempts (only if
# we're querying devices for the right user ID), then raise a
# NotRetryingDestination error to fail the resync gracefully.
- def query_user_devices(destination, user_id):
+ def query_user_devices(
+ destination: str, user_id: str, timeout: int = 30000
+ ) -> JsonDict:
if user_id == remote_user_id:
self.resync_attempts += 1
raise NotRetryingDestination(0, 0, destination)
# Register the mock on the federation client.
- federation_client = self.homeserver.get_federation_client()
+ federation_client = self.hs.get_federation_client()
federation_client.query_user_devices = Mock(side_effect=query_user_devices)
# Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user.
- store = self.homeserver.get_datastores().main
+ store = self.hs.get_datastores().main
store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
# Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried.
- device_list_updater = self.homeserver.get_device_handler().device_list_updater
+ device_list_updater = self.hs.get_device_handler().device_list_updater
self.get_success(
device_list_updater.incoming_device_list_update(
origin=remote_origin,
@@ -218,7 +226,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.reactor.advance(30)
self.assertEqual(self.resync_attempts, 2)
- def test_cross_signing_keys_retry(self):
+ def test_cross_signing_keys_retry(self) -> None:
"""Tests that resyncing a device list correctly processes cross-signing keys from
the remote server.
"""
@@ -227,7 +235,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
# Register mock device list retrieval on the federation client.
- federation_client = self.homeserver.get_federation_client()
+ federation_client = self.hs.get_federation_client()
federation_client.query_user_devices = Mock(
return_value=make_awaitable(
{
@@ -252,7 +260,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
# Resync the device list.
- device_handler = self.homeserver.get_device_handler()
+ device_handler = self.hs.get_device_handler()
self.get_success(
device_handler.device_list_updater.user_device_resync(remote_user_id),
)
@@ -279,7 +287,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
class StripUnsignedFromEventsTestCase(unittest.TestCase):
- def test_strip_unauthorized_unsigned_values(self):
+ def test_strip_unauthorized_unsigned_values(self) -> None:
event1 = {
"sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv",
@@ -296,7 +304,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
# Make sure unauthorized fields are stripped from unsigned
self.assertNotIn("more warez", filtered_event.unsigned)
- def test_strip_event_maintains_allowed_fields(self):
+ def test_strip_event_maintains_allowed_fields(self) -> None:
event2 = {
"sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv",
@@ -323,7 +331,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
self.assertIn("invite_room_state", filtered_event2.unsigned)
self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
- def test_strip_event_removes_fields_based_on_event_type(self):
+ def test_strip_event_removes_fields_based_on_event_type(self) -> None:
event3 = {
"sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv",
diff --git a/tests/test_mau.py b/tests/test_mau.py
index f14fcb7db9..4e7665a22b 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -14,12 +14,17 @@
"""Tests REST events for /rooms paths."""
-from typing import List
+from typing import List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.appservice import ApplicationService
from synapse.rest.client import register, sync
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.unittest import override_config
@@ -30,7 +35,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
servlets = [register.register_servlets, sync.register_servlets]
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = default_config("test")
config.update(
@@ -53,10 +58,12 @@ class TestMauLimit(unittest.HomeserverTestCase):
return config
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
- def test_simple_deny_mau(self):
+ def test_simple_deny_mau(self) -> None:
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -75,7 +82,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- def test_as_ignores_mau(self):
+ def test_as_ignores_mau(self) -> None:
"""Test that application services can still create users when the MAU
limit has been reached. This only works when application service
user ip tracking is disabled.
@@ -113,7 +120,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.create_user("as_kermit4", token=as_token, appservice=True)
- def test_allowed_after_a_month_mau(self):
+ def test_allowed_after_a_month_mau(self) -> None:
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -132,7 +139,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.do_sync_for_user(token3)
@override_config({"mau_trial_days": 1})
- def test_trial_delay(self):
+ def test_trial_delay(self) -> None:
# We should be able to register more than the limit initially
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -165,7 +172,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
@override_config({"mau_trial_days": 1})
- def test_trial_users_cant_come_back(self):
+ def test_trial_users_cant_come_back(self) -> None:
self.hs.config.server.mau_trial_days = 1
# We should be able to register more than the limit initially
@@ -216,7 +223,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
# max_mau_value should not matter
{"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True}
)
- def test_tracked_but_not_limited(self):
+ def test_tracked_but_not_limited(self) -> None:
# Simply being able to create 2 users indicates that the
# limit was not reached.
token1 = self.create_user("kermit1")
@@ -236,10 +243,10 @@ class TestMauLimit(unittest.HomeserverTestCase):
"mau_appservice_trial_days": {"SomeASID": 1, "AnotherASID": 2},
}
)
- def test_as_trial_days(self):
+ def test_as_trial_days(self) -> None:
user_tokens: List[str] = []
- def advance_time_and_sync():
+ def advance_time_and_sync() -> None:
self.reactor.advance(24 * 60 * 61)
for token in user_tokens:
self.do_sync_for_user(token)
@@ -300,7 +307,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
},
)
- def create_user(self, localpart, token=None, appservice=False):
+ def create_user(
+ self, localpart: str, token: Optional[str] = None, appservice: bool = False
+ ) -> str:
request_data = {
"username": localpart,
"password": "monkey",
@@ -326,7 +335,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
return access_token
- def do_sync_for_user(self, token):
+ def do_sync_for_user(self, token: str) -> None:
channel = self.make_request("GET", "/sync", access_token=token)
if channel.code != 200:
diff --git a/tests/test_rust.py b/tests/test_rust.py
index 55d8b6b28c..67443b6280 100644
--- a/tests/test_rust.py
+++ b/tests/test_rust.py
@@ -6,6 +6,6 @@ from tests import unittest
class RustTestCase(unittest.TestCase):
"""Basic tests to ensure that we can call into Rust code."""
- def test_basic(self):
+ def test_basic(self) -> None:
result = sum_as_string(1, 2)
self.assertEqual("3", result)
diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py
index d04bcae0fa..5cd698147e 100644
--- a/tests/test_test_utils.py
+++ b/tests/test_test_utils.py
@@ -17,25 +17,25 @@ from tests.utils import MockClock
class MockClockTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.clock = MockClock()
- def test_advance_time(self):
+ def test_advance_time(self) -> None:
start_time = self.clock.time()
self.clock.advance_time(20)
self.assertEqual(20, self.clock.time() - start_time)
- def test_later(self):
+ def test_later(self) -> None:
invoked = [0, 0]
- def _cb0():
+ def _cb0() -> None:
invoked[0] = 1
self.clock.call_later(10, _cb0)
- def _cb1():
+ def _cb1() -> None:
invoked[1] = 1
self.clock.call_later(20, _cb1)
@@ -51,15 +51,15 @@ class MockClockTestCase(unittest.TestCase):
self.assertTrue(invoked[1])
- def test_cancel_later(self):
+ def test_cancel_later(self) -> None:
invoked = [0, 0]
- def _cb0():
+ def _cb0() -> None:
invoked[0] = 1
t0 = self.clock.call_later(10, _cb0)
- def _cb1():
+ def _cb1() -> None:
invoked[1] = 1
self.clock.call_later(20, _cb1)
diff --git a/tests/test_types.py b/tests/test_types.py
index 1111169384..c491cc9a96 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -43,34 +43,34 @@ class IsMineIDTests(unittest.HomeserverTestCase):
class UserIDTestCase(unittest.HomeserverTestCase):
- def test_parse(self):
+ def test_parse(self) -> None:
user = UserID.from_string("@1234abcd:test")
self.assertEqual("1234abcd", user.localpart)
self.assertEqual("test", user.domain)
self.assertEqual(True, self.hs.is_mine(user))
- def test_parse_rejects_empty_id(self):
+ def test_parse_rejects_empty_id(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("")
- def test_parse_rejects_missing_sigil(self):
+ def test_parse_rejects_missing_sigil(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("alice:example.com")
- def test_parse_rejects_missing_separator(self):
+ def test_parse_rejects_missing_separator(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("@alice.example.com")
- def test_validation_rejects_missing_domain(self):
+ def test_validation_rejects_missing_domain(self) -> None:
self.assertFalse(UserID.is_valid("@alice:"))
- def test_build(self):
+ def test_build(self) -> None:
user = UserID("5678efgh", "my.domain")
self.assertEqual(user.to_string(), "@5678efgh:my.domain")
- def test_compare(self):
+ def test_compare(self) -> None:
userA = UserID.from_string("@userA:my.domain")
userAagain = UserID.from_string("@userA:my.domain")
userB = UserID.from_string("@userB:my.domain")
@@ -80,43 +80,43 @@ class UserIDTestCase(unittest.HomeserverTestCase):
class RoomAliasTestCase(unittest.HomeserverTestCase):
- def test_parse(self):
+ def test_parse(self) -> None:
room = RoomAlias.from_string("#channel:test")
self.assertEqual("channel", room.localpart)
self.assertEqual("test", room.domain)
self.assertEqual(True, self.hs.is_mine(room))
- def test_build(self):
+ def test_build(self) -> None:
room = RoomAlias("channel", "my.domain")
self.assertEqual(room.to_string(), "#channel:my.domain")
- def test_validate(self):
+ def test_validate(self) -> None:
id_string = "#test:domain,test"
self.assertFalse(RoomAlias.is_valid(id_string))
class MapUsernameTestCase(unittest.TestCase):
- def testPassThrough(self):
+ def test_pass_througuh(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
- def testUpperCase(self):
+ def test_upper_case(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
self.assertEqual(
map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
"t_e_s_t__1234",
)
- def testSymbols(self):
+ def test_symbols(self) -> None:
self.assertEqual(
map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234"
)
- def testLeadingUnderscore(self):
+ def test_leading_underscore(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")
- def testNonAscii(self):
+ def test_non_ascii(self) -> None:
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index e62ebcc6a5..e5dae670a7 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -20,12 +20,13 @@ import sys
import warnings
from asyncio import Future
from binascii import unhexlify
-from typing import Awaitable, Callable, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
from unittest.mock import Mock
import attr
import zope.interface
+from twisted.internet.interfaces import IProtocol
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from twisted.web.http import RESPONSES
@@ -34,6 +35,9 @@ from twisted.web.iweb import IResponse
from synapse.types import JsonDict
+if TYPE_CHECKING:
+ from sys import UnraisableHookArgs
+
TV = TypeVar("TV")
@@ -78,25 +82,29 @@ def setup_awaitable_errors() -> Callable[[], None]:
unraisable_exceptions = []
orig_unraisablehook = sys.unraisablehook
- def unraisablehook(unraisable):
+ def unraisablehook(unraisable: "UnraisableHookArgs") -> None:
unraisable_exceptions.append(unraisable.exc_value)
- def cleanup():
+ def cleanup() -> None:
"""
A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
"""
sys.unraisablehook = orig_unraisablehook
if unraisable_exceptions:
- raise unraisable_exceptions.pop()
+ exc = unraisable_exceptions.pop()
+ assert exc is not None
+ raise exc
sys.unraisablehook = unraisablehook
return cleanup
-def simple_async_mock(return_value=None, raises=None) -> Mock:
+def simple_async_mock(
+ return_value: Optional[TV] = None, raises: Optional[Exception] = None
+) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour
- async def cb(*args, **kwargs):
+ async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
if raises:
raise raises
return return_value
@@ -125,14 +133,14 @@ class FakeResponse: # type: ignore[misc]
headers: Headers = attr.Factory(Headers)
@property
- def phrase(self):
+ def phrase(self) -> bytes:
return RESPONSES.get(self.code, b"Unknown Status")
@property
- def length(self):
+ def length(self) -> int:
return len(self.body)
- def deliverBody(self, protocol):
+ def deliverBody(self, protocol: IProtocol) -> None:
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 8027c7a856..a6330ed840 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
import synapse.server
from synapse.api.constants import EventTypes
@@ -32,7 +32,7 @@ async def inject_member_event(
membership: str,
target: Optional[str] = None,
extra_content: Optional[dict] = None,
- **kwargs,
+ **kwargs: Any,
) -> EventBase:
"""Inject a membership event into a room."""
if target is None:
@@ -57,7 +57,7 @@ async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
- **kwargs,
+ **kwargs: Any,
) -> EventBase:
"""Inject a generic event into a room
@@ -82,7 +82,7 @@ async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
- **kwargs,
+ **kwargs: Any,
) -> Tuple[EventBase, EventContext]:
if room_version is None:
room_version = await hs.get_datastores().main.get_room_version_id(
@@ -92,8 +92,13 @@ async def create_event(
builder = hs.get_event_builder_factory().for_room_version(
KNOWN_ROOM_VERSIONS[room_version], kwargs
)
- event, context = await hs.get_event_creation_handler().create_new_client_event(
+ (
+ event,
+ unpersisted_context,
+ ) = await hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids
)
+ context = await unpersisted_context.persist(event)
+
return event, context
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
index e878af5f12..189c697efb 100644
--- a/tests/test_utils/html_parsers.py
+++ b/tests/test_utils/html_parsers.py
@@ -13,13 +13,13 @@
# limitations under the License.
from html.parser import HTMLParser
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, NoReturn, Optional, Tuple
class TestHtmlParser(HTMLParser):
"""A generic HTML page parser which extracts useful things from the HTML"""
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
# a list of links found in the doc
@@ -48,5 +48,5 @@ class TestHtmlParser(HTMLParser):
assert input_name
self.hiddens[input_name] = attr_dict["value"]
- def error(_, message):
+ def error(self, message: str) -> NoReturn:
raise AssertionError(message)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 304c7b98c5..b522163a34 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -25,7 +25,7 @@ class ToTwistedHandler(logging.Handler):
tx_log = twisted.logger.Logger()
- def emit(self, record):
+ def emit(self, record: logging.LogRecord) -> None:
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
self.tx_log.emit(
@@ -33,7 +33,7 @@ class ToTwistedHandler(logging.Handler):
)
-def setup_logging():
+def setup_logging() -> None:
"""Configure the python logging appropriately for the tests.
(Logs will end up in _trial_temp.)
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index 1461d23ee8..d555b24255 100644
--- a/tests/test_utils/oidc.py
+++ b/tests/test_utils/oidc.py
@@ -14,7 +14,7 @@
import json
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, ContextManager, Dict, List, Optional, Tuple
from unittest.mock import Mock, patch
from urllib.parse import parse_qs
@@ -77,14 +77,14 @@ class FakeOidcServer:
self._id_token_overrides: Dict[str, Any] = {}
- def reset_mocks(self):
+ def reset_mocks(self) -> None:
self.request.reset_mock()
self.get_jwks_handler.reset_mock()
self.get_metadata_handler.reset_mock()
self.get_userinfo_handler.reset_mock()
self.post_token_handler.reset_mock()
- def patch_homeserver(self, hs: HomeServer):
+ def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]:
"""Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
This patch should be used whenever the HS is expected to perform request to the
@@ -188,7 +188,7 @@ class FakeOidcServer:
return self._sign(logout_token)
- def id_token_override(self, overrides: dict):
+ def id_token_override(self, overrides: dict) -> ContextManager[dict]:
"""Temporarily patch the ID token generated by the token endpoint."""
return patch.object(self, "_id_token_overrides", overrides)
@@ -247,7 +247,7 @@ class FakeOidcServer:
metadata: bool = False,
token: bool = False,
userinfo: bool = False,
- ):
+ ) -> ContextManager[Dict[str, Mock]]:
"""A context which makes a set of endpoints return a 500 error.
Args:
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index d0b9ad5454..36d6b37aa4 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -175,9 +175,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
)
@@ -202,9 +203,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
@@ -226,9 +228,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
@@ -258,7 +261,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
- def test_out_of_band_invite_rejection(self):
+ def test_out_of_band_invite_rejection(self) -> None:
# this is where we have received an invite event over federation, and then
# rejected it.
invite_pdu = {
diff --git a/tests/unittest.py b/tests/unittest.py
index fa92dd94eb..68e59a88dc 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -315,7 +315,7 @@ class HomeserverTestCase(TestCase):
# This has to be a function and not just a Mock, because
# `self.helper.auth_user_id` is temporarily reassigned in some tests
- async def get_requester(*args, **kwargs) -> Requester:
+ async def get_requester(*args: Any, **kwargs: Any) -> Requester:
assert self.helper.auth_user_id is not None
return create_requester(
user_id=UserID.from_string(self.helper.auth_user_id),
diff --git a/tests/utils.py b/tests/utils.py
index d76bf9716a..15fabbc2d0 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -335,6 +335,9 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
},
)
- event, context = await event_creation_handler.create_new_client_event(builder)
+ event, unpersisted_context = await event_creation_handler.create_new_client_event(
+ builder
+ )
+ context = await unpersisted_context.persist(event)
await persistence_store.persist_event(event, context)
|