summary refs log tree commit diff
path: root/tests/events/test_presence_router.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/events/test_presence_router.py')
-rw-r--r--tests/events/test_presence_router.py58
1 files changed, 35 insertions, 23 deletions
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index b703e4472e..a9893def74 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -16,6 +16,8 @@ from unittest.mock import Mock
 
 import attr
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EduTypes
 from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router
 from synapse.federation.units import Transaction
@@ -23,11 +25,13 @@ from synapse.handlers.presence import UserPresenceState
 from synapse.module_api import ModuleApi
 from synapse.rest import admin
 from synapse.rest.client import login, presence, room
+from synapse.server import HomeServer
 from synapse.types import JsonDict, StreamToken, create_requester
+from synapse.util import Clock
 
 from tests.handlers.test_sync import generate_sync_config
 from tests.test_utils import simple_async_mock
-from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
+from tests.unittest import FederatingHomeserverTestCase, override_config
 
 
 @attr.s
@@ -49,9 +53,7 @@ class LegacyPresenceRouterTestModule:
         }
         return users_to_state
 
-    async def get_interested_users(
-        self, user_id: str
-    ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
         if user_id in self._config.users_who_should_receive_all_presence:
             return PresenceRouter.ALL_USERS
 
@@ -71,9 +73,14 @@ class LegacyPresenceRouterTestModule:
         # Initialise a typed config object
         config = PresenceRouterTestConfig()
 
-        config.users_who_should_receive_all_presence = config_dict.get(
+        users_who_should_receive_all_presence = config_dict.get(
             "users_who_should_receive_all_presence"
         )
+        assert isinstance(users_who_should_receive_all_presence, list)
+
+        config.users_who_should_receive_all_presence = (
+            users_who_should_receive_all_presence
+        )
 
         return config
 
@@ -96,9 +103,7 @@ class PresenceRouterTestModule:
         }
         return users_to_state
 
-    async def get_interested_users(
-        self, user_id: str
-    ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
         if user_id in self._config.users_who_should_receive_all_presence:
             return PresenceRouter.ALL_USERS
 
@@ -118,9 +123,14 @@ class PresenceRouterTestModule:
         # Initialise a typed config object
         config = PresenceRouterTestConfig()
 
-        config.users_who_should_receive_all_presence = config_dict.get(
+        users_who_should_receive_all_presence = config_dict.get(
             "users_who_should_receive_all_presence"
         )
+        assert isinstance(users_who_should_receive_all_presence, list)
+
+        config.users_who_should_receive_all_presence = (
+            users_who_should_receive_all_presence
+        )
 
         return config
 
@@ -140,7 +150,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         presence.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
         fed_transport_client = Mock(spec=["send_transaction"])
         fed_transport_client.send_transaction = simple_async_mock({})
@@ -153,7 +163,9 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
 
         return hs
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
         self.sync_handler = self.hs.get_sync_handler()
         self.module_api = homeserver.get_module_api()
 
@@ -176,7 +188,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             },
         }
     )
-    def test_receiving_all_presence_legacy(self):
+    def test_receiving_all_presence_legacy(self) -> None:
         self.receiving_all_presence_test_body()
 
     @override_config(
@@ -193,10 +205,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             ],
         }
     )
-    def test_receiving_all_presence(self):
+    def test_receiving_all_presence(self) -> None:
         self.receiving_all_presence_test_body()
 
-    def receiving_all_presence_test_body(self):
+    def receiving_all_presence_test_body(self) -> None:
         """Test that a user that does not share a room with another other can receive
         presence for them, due to presence routing.
         """
@@ -302,7 +314,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             },
         }
     )
-    def test_send_local_online_presence_to_with_module_legacy(self):
+    def test_send_local_online_presence_to_with_module_legacy(self) -> None:
         self.send_local_online_presence_to_with_module_test_body()
 
     @override_config(
@@ -321,10 +333,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             ],
         }
     )
-    def test_send_local_online_presence_to_with_module(self):
+    def test_send_local_online_presence_to_with_module(self) -> None:
         self.send_local_online_presence_to_with_module_test_body()
 
-    def send_local_online_presence_to_with_module_test_body(self):
+    def send_local_online_presence_to_with_module_test_body(self) -> None:
         """Tests that send_local_presence_to_users sends local online presence to a set
         of specified local and remote users, with a custom PresenceRouter module enabled.
         """
@@ -447,18 +459,18 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
                     continue
 
                 # EDUs can contain multiple presence updates
-                for presence_update in edu["content"]["push"]:
+                for presence_edu in edu["content"]["push"]:
                     # Check for presence updates that contain the user IDs we're after
-                    found_users.add(presence_update["user_id"])
+                    found_users.add(presence_edu["user_id"])
 
                     # Ensure that no offline states are being sent out
-                    self.assertNotEqual(presence_update["presence"], "offline")
+                    self.assertNotEqual(presence_edu["presence"], "offline")
 
         self.assertEqual(found_users, expected_users)
 
 
 def send_presence_update(
-    testcase: TestCase,
+    testcase: FederatingHomeserverTestCase,
     user_id: str,
     access_token: str,
     presence_state: str,
@@ -479,7 +491,7 @@ def send_presence_update(
 
 
 def sync_presence(
-    testcase: TestCase,
+    testcase: FederatingHomeserverTestCase,
     user_id: str,
     since_token: Optional[StreamToken] = None,
 ) -> Tuple[List[UserPresenceState], StreamToken]:
@@ -500,7 +512,7 @@ def sync_presence(
     requester = create_requester(user_id)
     sync_config = generate_sync_config(requester.user.to_string())
     sync_result = testcase.get_success(
-        testcase.sync_handler.wait_for_sync_for_user(
+        testcase.hs.get_sync_handler().wait_for_sync_for_user(
             requester, sync_config, since_token
         )
     )