summary refs log tree commit diff
path: root/tests/module_api/test_api.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/module_api/test_api.py')
-rw-r--r--tests/module_api/test_api.py122
1 files changed, 72 insertions, 50 deletions
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 8f88c0117d..cc173ebda6 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Dict
 from unittest.mock import Mock
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EduTypes, EventTypes
 from synapse.api.errors import NotFoundError
@@ -21,9 +23,12 @@ from synapse.events import EventBase
 from synapse.federation.units import Transaction
 from synapse.handlers.presence import UserPresenceState
 from synapse.handlers.push_rules import InvalidRuleException
+from synapse.module_api import ModuleApi
 from synapse.rest import admin
 from synapse.rest.client import login, notifications, presence, profile, room
-from synapse.types import create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, create_requester
+from synapse.util import Clock
 
 from tests.events.test_presence_router import send_presence_update, sync_presence
 from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -32,7 +37,19 @@ from tests.test_utils.event_injection import inject_member_event
 from tests.unittest import HomeserverTestCase, override_config
 
 
-class ModuleApiTestCase(HomeserverTestCase):
+class BaseModuleApiTestCase(HomeserverTestCase):
+    """Common properties of the two test case classes."""
+
+    module_api: ModuleApi
+
+    # These are all written by _test_sending_local_online_presence_to_local_user.
+    presence_receiver_id: str
+    presence_receiver_tok: str
+    presence_sender_id: str
+    presence_sender_tok: str
+
+
+class ModuleApiTestCase(BaseModuleApiTestCase):
     servlets = [
         admin.register_servlets,
         login.register_servlets,
@@ -42,14 +59,14 @@ class ModuleApiTestCase(HomeserverTestCase):
         notifications.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastores().main
-        self.module_api = homeserver.get_module_api()
-        self.event_creation_handler = homeserver.get_event_creation_handler()
-        self.sync_handler = homeserver.get_sync_handler()
-        self.auth_handler = homeserver.get_auth_handler()
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.module_api = hs.get_module_api()
+        self.event_creation_handler = hs.get_event_creation_handler()
+        self.sync_handler = hs.get_sync_handler()
+        self.auth_handler = hs.get_auth_handler()
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
         fed_transport_client = Mock(spec=["send_transaction"])
         fed_transport_client.send_transaction = simple_async_mock({})
@@ -58,7 +75,7 @@ class ModuleApiTestCase(HomeserverTestCase):
             federation_transport_client=fed_transport_client,
         )
 
-    def test_can_register_user(self):
+    def test_can_register_user(self) -> None:
         """Tests that an external module can register a user"""
         # Register a new user
         user_id, access_token = self.get_success(
@@ -88,16 +105,17 @@ class ModuleApiTestCase(HomeserverTestCase):
         displayname = self.get_success(self.store.get_profile_displayname("bob"))
         self.assertEqual(displayname, "Bobberino")
 
-    def test_can_register_admin_user(self):
+    def test_can_register_admin_user(self) -> None:
         user_id = self.register_user(
             "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
         )
 
         found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+        assert found_user is not None
         self.assertEqual(found_user.user_id.to_string(), user_id)
         self.assertIdentical(found_user.is_admin, True)
 
-    def test_can_set_admin(self):
+    def test_can_set_admin(self) -> None:
         user_id = self.register_user(
             "alice_wants_admin",
             "1234",
@@ -107,16 +125,17 @@ class ModuleApiTestCase(HomeserverTestCase):
 
         self.get_success(self.module_api.set_user_admin(user_id, True))
         found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+        assert found_user is not None
         self.assertEqual(found_user.user_id.to_string(), user_id)
         self.assertIdentical(found_user.is_admin, True)
 
-    def test_can_set_displayname(self):
+    def test_can_set_displayname(self) -> None:
         localpart = "alice_wants_a_new_displayname"
         user_id = self.register_user(
             localpart, "1234", displayname="Alice", admin=False
         )
         found_userinfo = self.get_success(self.module_api.get_userinfo_by_id(user_id))
-
+        assert found_userinfo is not None
         self.get_success(
             self.module_api.set_displayname(
                 found_userinfo.user_id, "Bob", deactivation=False
@@ -128,17 +147,18 @@ class ModuleApiTestCase(HomeserverTestCase):
 
         self.assertEqual(found_profile.display_name, "Bob")
 
-    def test_get_userinfo_by_id(self):
+    def test_get_userinfo_by_id(self) -> None:
         user_id = self.register_user("alice", "1234")
         found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+        assert found_user is not None
         self.assertEqual(found_user.user_id.to_string(), user_id)
         self.assertIdentical(found_user.is_admin, False)
 
-    def test_get_userinfo_by_id__no_user_found(self):
+    def test_get_userinfo_by_id__no_user_found(self) -> None:
         found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
         self.assertIsNone(found_user)
 
-    def test_get_user_ip_and_agents(self):
+    def test_get_user_ip_and_agents(self) -> None:
         user_id = self.register_user("test_get_user_ip_and_agents_user", "1234")
 
         # Initially, we should have no ip/agent for our user.
@@ -185,7 +205,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         # we should only find the second ip, agent.
         info = self.get_success(
             self.module_api.get_user_ip_and_agents(
-                user_id, (last_seen_1 + last_seen_2) / 2
+                user_id, (last_seen_1 + last_seen_2) // 2
             )
         )
         self.assertEqual(len(info), 1)
@@ -200,7 +220,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         )
         self.assertEqual(info, [])
 
-    def test_get_user_ip_and_agents__no_user_found(self):
+    def test_get_user_ip_and_agents__no_user_found(self) -> None:
         info = self.get_success(
             self.module_api.get_user_ip_and_agents(
                 "@test_get_user_ip_and_agents_user_nonexistent:example.com"
@@ -208,10 +228,10 @@ class ModuleApiTestCase(HomeserverTestCase):
         )
         self.assertEqual(info, [])
 
-    def test_sending_events_into_room(self):
+    def test_sending_events_into_room(self) -> None:
         """Tests that a module can send events into a room"""
         # Mock out create_and_send_nonmember_event to check whether events are being sent
-        self.event_creation_handler.create_and_send_nonmember_event = Mock(
+        self.event_creation_handler.create_and_send_nonmember_event = Mock(  # type: ignore[assignment]
             spec=[],
             side_effect=self.event_creation_handler.create_and_send_nonmember_event,
         )
@@ -222,7 +242,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         room_id = self.helper.create_room_as(user_id, tok=tok)
 
         # Create and send a non-state event
-        content = {"body": "I am a puppet", "msgtype": "m.text"}
+        content: JsonDict = {"body": "I am a puppet", "msgtype": "m.text"}
         event_dict = {
             "room_id": room_id,
             "type": "m.room.message",
@@ -265,7 +285,7 @@ class ModuleApiTestCase(HomeserverTestCase):
             "sender": user_id,
             "state_key": "",
         }
-        event: EventBase = self.get_success(
+        event = self.get_success(
             self.module_api.create_and_send_event_into_room(event_dict)
         )
         self.assertEqual(event.sender, user_id)
@@ -303,7 +323,7 @@ class ModuleApiTestCase(HomeserverTestCase):
             self.module_api.create_and_send_event_into_room(event_dict), Exception
         )
 
-    def test_public_rooms(self):
+    def test_public_rooms(self) -> None:
         """Tests that a room can be added and removed from the public rooms list,
         as well as have its public rooms directory state queried.
         """
@@ -350,13 +370,13 @@ class ModuleApiTestCase(HomeserverTestCase):
         )
         self.assertFalse(is_in_public_rooms)
 
-    def test_send_local_online_presence_to(self):
+    def test_send_local_online_presence_to(self) -> None:
         # Test sending local online presence to users from the main process
         _test_sending_local_online_presence_to_local_user(self, test_with_workers=False)
 
     # Enable federation sending on the main process.
     @override_config({"federation_sender_instances": None})
-    def test_send_local_online_presence_to_federation(self):
+    def test_send_local_online_presence_to_federation(self) -> None:
         """Tests that send_local_presence_to_users sends local online presence to remote users."""
         # Create a user who will send presence updates
         self.presence_sender_id = self.register_user("presence_sender1", "monkey")
@@ -431,7 +451,7 @@ class ModuleApiTestCase(HomeserverTestCase):
 
         self.assertTrue(found_update)
 
-    def test_update_membership(self):
+    def test_update_membership(self) -> None:
         """Tests that the module API can update the membership of a user in a room."""
         peter = self.register_user("peter", "hackme")
         lesley = self.register_user("lesley", "hackme")
@@ -554,7 +574,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         self.assertEqual(res["displayname"], "simone")
         self.assertIsNone(res["avatar_url"])
 
-    def test_update_room_membership_remote_join(self):
+    def test_update_room_membership_remote_join(self) -> None:
         """Test that the module API can join a remote room."""
         # Necessary to fake a remote join.
         fake_stream_id = 1
@@ -582,7 +602,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         # Check that a remote join was attempted.
         self.assertEqual(mocked_remote_join.call_count, 1)
 
-    def test_get_room_state(self):
+    def test_get_room_state(self) -> None:
         """Tests that a module can retrieve the state of a room through the module API."""
         user_id = self.register_user("peter", "hackme")
         tok = self.login("peter", "hackme")
@@ -677,7 +697,7 @@ class ModuleApiTestCase(HomeserverTestCase):
             self.module_api.check_push_rule_actions(["foo"])
 
         with self.assertRaises(InvalidRuleException):
-            self.module_api.check_push_rule_actions({"foo": "bar"})
+            self.module_api.check_push_rule_actions([{"foo": "bar"}])
 
         self.module_api.check_push_rule_actions(["notify"])
 
@@ -756,7 +776,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         self.assertIsNone(room_alias)
 
 
-class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
+class ModuleApiWorkerTestCase(BaseModuleApiTestCase, BaseMultiWorkerStreamTestCase):
     """For testing ModuleApi functionality in a multi-worker setup"""
 
     servlets = [
@@ -766,7 +786,7 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
         presence.register_servlets,
     ]
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         conf = super().default_config()
         conf["stream_writers"] = {"presence": ["presence_writer"]}
         conf["instance_map"] = {
@@ -774,18 +794,18 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
         }
         return conf
 
-    def prepare(self, reactor, clock, homeserver):
-        self.module_api = homeserver.get_module_api()
-        self.sync_handler = homeserver.get_sync_handler()
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.module_api = hs.get_module_api()
+        self.sync_handler = hs.get_sync_handler()
 
-    def test_send_local_online_presence_to_workers(self):
+    def test_send_local_online_presence_to_workers(self) -> None:
         # Test sending local online presence to users from a worker process
         _test_sending_local_online_presence_to_local_user(self, test_with_workers=True)
 
 
 def _test_sending_local_online_presence_to_local_user(
-    test_case: HomeserverTestCase, test_with_workers: bool = False
-):
+    test_case: BaseModuleApiTestCase, test_with_workers: bool = False
+) -> None:
     """Tests that send_local_presence_to_users sends local online presence to local users.
 
     This simultaneously tests two different usecases:
@@ -852,6 +872,7 @@ def _test_sending_local_online_presence_to_local_user(
         # Replicate the current sync presence token from the main process to the worker process.
         # We need to do this so that the worker process knows the current presence stream ID to
         # insert into the database when we call ModuleApi.send_local_online_presence_to.
+        assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
         test_case.replicate()
 
     # Syncing again should result in no presence updates
@@ -868,6 +889,7 @@ def _test_sending_local_online_presence_to_local_user(
 
     # Determine on which process (main or worker) to call ModuleApi.send_local_online_presence_to on
     if test_with_workers:
+        assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
         module_api_to_use = worker_hs.get_module_api()
     else:
         module_api_to_use = test_case.module_api
@@ -875,12 +897,11 @@ def _test_sending_local_online_presence_to_local_user(
     # Trigger sending local online presence. We expect this information
     # to be saved to the database where all processes can access it.
     # Note that we're syncing via the master.
-    d = module_api_to_use.send_local_online_presence_to(
-        [
-            test_case.presence_receiver_id,
-        ]
+    d = defer.ensureDeferred(
+        module_api_to_use.send_local_online_presence_to(
+            [test_case.presence_receiver_id],
+        )
     )
-    d = defer.ensureDeferred(d)
 
     if test_with_workers:
         # In order for the required presence_set_state replication request to occur between the
@@ -897,7 +918,7 @@ def _test_sending_local_online_presence_to_local_user(
     )
     test_case.assertEqual(len(presence_updates), 1)
 
-    presence_update: UserPresenceState = presence_updates[0]
+    presence_update = presence_updates[0]
     test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
     test_case.assertEqual(presence_update.state, "online")
 
@@ -908,7 +929,7 @@ def _test_sending_local_online_presence_to_local_user(
     )
     test_case.assertEqual(len(presence_updates), 1)
 
-    presence_update: UserPresenceState = presence_updates[0]
+    presence_update = presence_updates[0]
     test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
     test_case.assertEqual(presence_update.state, "online")
 
@@ -936,12 +957,13 @@ def _test_sending_local_online_presence_to_local_user(
     test_case.assertEqual(len(presence_updates), 1)
 
     # Now trigger sending local online presence.
-    d = module_api_to_use.send_local_online_presence_to(
-        [
-            test_case.presence_receiver_id,
-        ]
+    d = defer.ensureDeferred(
+        module_api_to_use.send_local_online_presence_to(
+            [
+                test_case.presence_receiver_id,
+            ]
+        )
     )
-    d = defer.ensureDeferred(d)
 
     if test_with_workers:
         # In order for the required presence_set_state replication request to occur between the