summary refs log tree commit diff
path: root/tests/events
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-01-25 15:14:03 -0500
committerGitHub <noreply@github.com>2023-01-25 15:14:03 -0500
commit3c3ba31507cbff27064ea3c6cf1e7add9583556a (patch)
treee75605d1d022d4eb60c7328eacf027e20ae151a7 /tests/events
parentDocument how to handle Dependabot pull requests. (#14916) (diff)
downloadsynapse-3c3ba31507cbff27064ea3c6cf1e7add9583556a.tar.xz
Add missing type hints for tests.events. (#14904)
Diffstat (limited to 'tests/events')
-rw-r--r--tests/events/test_presence_router.py58
-rw-r--r--tests/events/test_snapshot.py17
-rw-r--r--tests/events/test_utils.py71
3 files changed, 85 insertions, 61 deletions
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index b703e4472e..a9893def74 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -16,6 +16,8 @@ from unittest.mock import Mock
 
 import attr
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EduTypes
 from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router
 from synapse.federation.units import Transaction
@@ -23,11 +25,13 @@ from synapse.handlers.presence import UserPresenceState
 from synapse.module_api import ModuleApi
 from synapse.rest import admin
 from synapse.rest.client import login, presence, room
+from synapse.server import HomeServer
 from synapse.types import JsonDict, StreamToken, create_requester
+from synapse.util import Clock
 
 from tests.handlers.test_sync import generate_sync_config
 from tests.test_utils import simple_async_mock
-from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
+from tests.unittest import FederatingHomeserverTestCase, override_config
 
 
 @attr.s
@@ -49,9 +53,7 @@ class LegacyPresenceRouterTestModule:
         }
         return users_to_state
 
-    async def get_interested_users(
-        self, user_id: str
-    ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
         if user_id in self._config.users_who_should_receive_all_presence:
             return PresenceRouter.ALL_USERS
 
@@ -71,9 +73,14 @@ class LegacyPresenceRouterTestModule:
         # Initialise a typed config object
         config = PresenceRouterTestConfig()
 
-        config.users_who_should_receive_all_presence = config_dict.get(
+        users_who_should_receive_all_presence = config_dict.get(
             "users_who_should_receive_all_presence"
         )
+        assert isinstance(users_who_should_receive_all_presence, list)
+
+        config.users_who_should_receive_all_presence = (
+            users_who_should_receive_all_presence
+        )
 
         return config
 
@@ -96,9 +103,7 @@ class PresenceRouterTestModule:
         }
         return users_to_state
 
-    async def get_interested_users(
-        self, user_id: str
-    ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
         if user_id in self._config.users_who_should_receive_all_presence:
             return PresenceRouter.ALL_USERS
 
@@ -118,9 +123,14 @@ class PresenceRouterTestModule:
         # Initialise a typed config object
         config = PresenceRouterTestConfig()
 
-        config.users_who_should_receive_all_presence = config_dict.get(
+        users_who_should_receive_all_presence = config_dict.get(
             "users_who_should_receive_all_presence"
         )
+        assert isinstance(users_who_should_receive_all_presence, list)
+
+        config.users_who_should_receive_all_presence = (
+            users_who_should_receive_all_presence
+        )
 
         return config
 
@@ -140,7 +150,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         presence.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
         fed_transport_client = Mock(spec=["send_transaction"])
         fed_transport_client.send_transaction = simple_async_mock({})
@@ -153,7 +163,9 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
 
         return hs
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
         self.sync_handler = self.hs.get_sync_handler()
         self.module_api = homeserver.get_module_api()
 
@@ -176,7 +188,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             },
         }
     )
-    def test_receiving_all_presence_legacy(self):
+    def test_receiving_all_presence_legacy(self) -> None:
         self.receiving_all_presence_test_body()
 
     @override_config(
@@ -193,10 +205,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             ],
         }
     )
-    def test_receiving_all_presence(self):
+    def test_receiving_all_presence(self) -> None:
         self.receiving_all_presence_test_body()
 
-    def receiving_all_presence_test_body(self):
+    def receiving_all_presence_test_body(self) -> None:
         """Test that a user that does not share a room with another other can receive
         presence for them, due to presence routing.
         """
@@ -302,7 +314,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             },
         }
     )
-    def test_send_local_online_presence_to_with_module_legacy(self):
+    def test_send_local_online_presence_to_with_module_legacy(self) -> None:
         self.send_local_online_presence_to_with_module_test_body()
 
     @override_config(
@@ -321,10 +333,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             ],
         }
     )
-    def test_send_local_online_presence_to_with_module(self):
+    def test_send_local_online_presence_to_with_module(self) -> None:
         self.send_local_online_presence_to_with_module_test_body()
 
-    def send_local_online_presence_to_with_module_test_body(self):
+    def send_local_online_presence_to_with_module_test_body(self) -> None:
         """Tests that send_local_presence_to_users sends local online presence to a set
         of specified local and remote users, with a custom PresenceRouter module enabled.
         """
@@ -447,18 +459,18 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
                     continue
 
                 # EDUs can contain multiple presence updates
-                for presence_update in edu["content"]["push"]:
+                for presence_edu in edu["content"]["push"]:
                     # Check for presence updates that contain the user IDs we're after
-                    found_users.add(presence_update["user_id"])
+                    found_users.add(presence_edu["user_id"])
 
                     # Ensure that no offline states are being sent out
-                    self.assertNotEqual(presence_update["presence"], "offline")
+                    self.assertNotEqual(presence_edu["presence"], "offline")
 
         self.assertEqual(found_users, expected_users)
 
 
 def send_presence_update(
-    testcase: TestCase,
+    testcase: FederatingHomeserverTestCase,
     user_id: str,
     access_token: str,
     presence_state: str,
@@ -479,7 +491,7 @@ def send_presence_update(
 
 
 def sync_presence(
-    testcase: TestCase,
+    testcase: FederatingHomeserverTestCase,
     user_id: str,
     since_token: Optional[StreamToken] = None,
 ) -> Tuple[List[UserPresenceState], StreamToken]:
@@ -500,7 +512,7 @@ def sync_presence(
     requester = create_requester(user_id)
     sync_config = generate_sync_config(requester.user.to_string())
     sync_result = testcase.get_success(
-        testcase.sync_handler.wait_for_sync_for_user(
+        testcase.hs.get_sync_handler().wait_for_sync_for_user(
             requester, sync_config, since_token
         )
     )
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 8ddce83b83..6687c28e8f 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -12,9 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils.event_injection import create_event
@@ -27,7 +32,7 @@ class TestEventContext(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
 
@@ -35,7 +40,7 @@ class TestEventContext(unittest.HomeserverTestCase):
         self.user_tok = self.login("u1", "pass")
         self.room_id = self.helper.create_room_as(tok=self.user_tok)
 
-    def test_serialize_deserialize_msg(self):
+    def test_serialize_deserialize_msg(self) -> None:
         """Test that an EventContext for a message event is the same after
         serialize/deserialize.
         """
@@ -51,7 +56,7 @@ class TestEventContext(unittest.HomeserverTestCase):
 
         self._check_serialize_deserialize(event, context)
 
-    def test_serialize_deserialize_state_no_prev(self):
+    def test_serialize_deserialize_state_no_prev(self) -> None:
         """Test that an EventContext for a state event (with not previous entry)
         is the same after serialize/deserialize.
         """
@@ -67,7 +72,7 @@ class TestEventContext(unittest.HomeserverTestCase):
 
         self._check_serialize_deserialize(event, context)
 
-    def test_serialize_deserialize_state_prev(self):
+    def test_serialize_deserialize_state_prev(self) -> None:
         """Test that an EventContext for a state event (which replaces a
         previous entry) is the same after serialize/deserialize.
         """
@@ -84,7 +89,9 @@ class TestEventContext(unittest.HomeserverTestCase):
 
         self._check_serialize_deserialize(event, context)
 
-    def _check_serialize_deserialize(self, event, context):
+    def _check_serialize_deserialize(
+        self, event: EventBase, context: EventContext
+    ) -> None:
         serialized = self.get_success(context.serialize(event, self.store))
 
         d_context = EventContext.deserialize(self._storage_controllers, serialized)
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index a79256846f..ff7b349d75 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -13,21 +13,24 @@
 # limitations under the License.
 
 import unittest as stdlib_unittest
+from typing import Any, List, Mapping, Optional
 
 from synapse.api.constants import EventContentFields
 from synapse.api.room_versions import RoomVersions
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import (
+    PowerLevelsContent,
     SerializeEventConfig,
     copy_and_fixup_power_levels_contents,
     maybe_upsert_event_field,
     prune_event,
     serialize_event,
 )
+from synapse.types import JsonDict
 from synapse.util.frozenutils import freeze
 
 
-def MockEvent(**kwargs):
+def MockEvent(**kwargs: Any) -> EventBase:
     if "event_id" not in kwargs:
         kwargs["event_id"] = "fake_event_id"
     if "type" not in kwargs:
@@ -60,7 +63,7 @@ class TestMaybeUpsertEventField(stdlib_unittest.TestCase):
 
 
 class PruneEventTestCase(stdlib_unittest.TestCase):
-    def run_test(self, evdict, matchdict, **kwargs):
+    def run_test(self, evdict: JsonDict, matchdict: JsonDict, **kwargs: Any) -> None:
         """
         Asserts that a new event constructed with `evdict` will look like
         `matchdict` when it is redacted.
@@ -74,7 +77,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict
         )
 
-    def test_minimal(self):
+    def test_minimal(self) -> None:
         self.run_test(
             {"type": "A", "event_id": "$test:domain"},
             {
@@ -86,7 +89,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             },
         )
 
-    def test_basic_keys(self):
+    def test_basic_keys(self) -> None:
         """Ensure that the keys that should be untouched are kept."""
         # Note that some of the values below don't really make sense, but the
         # pruning of events doesn't worry about the values of any fields (with
@@ -138,7 +141,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             room_version=RoomVersions.MSC2176,
         )
 
-    def test_unsigned(self):
+    def test_unsigned(self) -> None:
         """Ensure that unsigned properties get stripped (except age_ts and replaces_state)."""
         self.run_test(
             {
@@ -159,7 +162,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             },
         )
 
-    def test_content(self):
+    def test_content(self) -> None:
         """The content dictionary should be stripped in most cases."""
         self.run_test(
             {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}},
@@ -194,7 +197,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
                 },
             )
 
-    def test_create(self):
+    def test_create(self) -> None:
         """Create events are partially redacted until MSC2176."""
         self.run_test(
             {
@@ -223,7 +226,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             room_version=RoomVersions.MSC2176,
         )
 
-    def test_power_levels(self):
+    def test_power_levels(self) -> None:
         """Power level events keep a variety of content keys."""
         self.run_test(
             {
@@ -273,7 +276,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             room_version=RoomVersions.MSC2176,
         )
 
-    def test_alias_event(self):
+    def test_alias_event(self) -> None:
         """Alias events have special behavior up through room version 6."""
         self.run_test(
             {
@@ -302,7 +305,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             room_version=RoomVersions.V6,
         )
 
-    def test_redacts(self):
+    def test_redacts(self) -> None:
         """Redaction events have no special behaviour until MSC2174/MSC2176."""
 
         self.run_test(
@@ -328,7 +331,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             room_version=RoomVersions.MSC2176,
         )
 
-    def test_join_rules(self):
+    def test_join_rules(self) -> None:
         """Join rules events have changed behavior starting with MSC3083."""
         self.run_test(
             {
@@ -371,7 +374,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             room_version=RoomVersions.V8,
         )
 
-    def test_member(self):
+    def test_member(self) -> None:
         """Member events have changed behavior starting with MSC3375."""
         self.run_test(
             {
@@ -417,12 +420,12 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
 
 
 class SerializeEventTestCase(stdlib_unittest.TestCase):
-    def serialize(self, ev, fields):
+    def serialize(self, ev: EventBase, fields: Optional[List[str]]) -> JsonDict:
         return serialize_event(
             ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
         )
 
-    def test_event_fields_works_with_keys(self):
+    def test_event_fields_works_with_keys(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"]
@@ -430,7 +433,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             {"room_id": "!foo:bar"},
         )
 
-    def test_event_fields_works_with_nested_keys(self):
+    def test_event_fields_works_with_nested_keys(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(
@@ -443,7 +446,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             {"content": {"body": "A message"}},
         )
 
-    def test_event_fields_works_with_dot_keys(self):
+    def test_event_fields_works_with_dot_keys(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(
@@ -456,7 +459,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             {"content": {"key.with.dots": {}}},
         )
 
-    def test_event_fields_works_with_nested_dot_keys(self):
+    def test_event_fields_works_with_nested_dot_keys(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(
@@ -472,7 +475,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             {"content": {"nested.dot.key": {"leaf.key": 42}}},
         )
 
-    def test_event_fields_nops_with_unknown_keys(self):
+    def test_event_fields_nops_with_unknown_keys(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(
@@ -485,7 +488,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             {"content": {"foo": "bar"}},
         )
 
-    def test_event_fields_nops_with_non_dict_keys(self):
+    def test_event_fields_nops_with_non_dict_keys(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(
@@ -498,7 +501,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             {},
         )
 
-    def test_event_fields_nops_with_array_keys(self):
+    def test_event_fields_nops_with_array_keys(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(
@@ -511,7 +514,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             {},
         )
 
-    def test_event_fields_all_fields_if_empty(self):
+    def test_event_fields_all_fields_if_empty(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(
@@ -531,16 +534,16 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             },
         )
 
-    def test_event_fields_fail_if_fields_not_str(self):
+    def test_event_fields_fail_if_fields_not_str(self) -> None:
         with self.assertRaises(TypeError):
             self.serialize(
-                MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4]
+                MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4]  # type: ignore[list-item]
             )
 
 
 class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
     def setUp(self) -> None:
-        self.test_content = {
+        self.test_content: PowerLevelsContent = {
             "ban": 50,
             "events": {"m.room.name": 100, "m.room.power_levels": 100},
             "events_default": 0,
@@ -553,10 +556,11 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
             "users_default": 0,
         }
 
-    def _test(self, input):
+    def _test(self, input: PowerLevelsContent) -> None:
         a = copy_and_fixup_power_levels_contents(input)
 
         self.assertEqual(a["ban"], 50)
+        assert isinstance(a["events"], Mapping)
         self.assertEqual(a["events"]["m.room.name"], 100)
 
         # make sure that changing the copy changes the copy and not the orig
@@ -564,18 +568,19 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
         a["events"]["m.room.power_levels"] = 20
 
         self.assertEqual(input["ban"], 50)
+        assert isinstance(input["events"], Mapping)
         self.assertEqual(input["events"]["m.room.power_levels"], 100)
 
-    def test_unfrozen(self):
+    def test_unfrozen(self) -> None:
         self._test(self.test_content)
 
-    def test_frozen(self):
+    def test_frozen(self) -> None:
         input = freeze(self.test_content)
         self._test(input)
 
-    def test_stringy_integers(self):
+    def test_stringy_integers(self) -> None:
         """String representations of decimal integers are converted to integers."""
-        input = {
+        input: PowerLevelsContent = {
             "a": "100",
             "b": {
                 "foo": 99,
@@ -603,9 +608,9 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
 
     def test_invalid_types_raise_type_error(self) -> None:
         with self.assertRaises(TypeError):
-            copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]})  # type: ignore[arg-type]
-            copy_and_fixup_power_levels_contents({"a": None})  # type: ignore[arg-type]
+            copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]})  # type: ignore[dict-item]
+            copy_and_fixup_power_levels_contents({"a": None})  # type: ignore[dict-item]
 
     def test_invalid_nesting_raises_type_error(self) -> None:
         with self.assertRaises(TypeError):
-            copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}})
+            copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}})  # type: ignore[dict-item]