summary refs log tree commit diff
path: root/tests/replication/test_pusher_shard.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/test_pusher_shard.py')
-rw-r--r--tests/replication/test_pusher_shard.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index ca18ad6553..0798b021c3 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -15,9 +15,12 @@ import logging
 from unittest.mock import Mock
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 
@@ -33,12 +36,12 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         login.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # Register a user who sends a message that we'll get notified about
         self.other_user_id = self.register_user("otheruser", "pass")
         self.other_access_token = self.login("otheruser", "pass")
 
-    def _create_pusher_and_send_msg(self, localpart):
+    def _create_pusher_and_send_msg(self, localpart: str) -> str:
         # Create a user that will get push notifications
         user_id = self.register_user(localpart, "pass")
         access_token = self.login(localpart, "pass")
@@ -47,6 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         user_dict = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_dict is not None
         token_id = user_dict.token_id
 
         self.get_success(
@@ -79,7 +83,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
 
         return event_id
 
-    def test_send_push_single_worker(self):
+    def test_send_push_single_worker(self) -> None:
         """Test that registration works when using a pusher worker."""
         http_client_mock = Mock(spec_set=["post_json_get_json"])
         http_client_mock.post_json_get_json.side_effect = (
@@ -109,7 +113,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
             ],
         )
 
-    def test_send_push_multiple_workers(self):
+    def test_send_push_multiple_workers(self) -> None:
         """Test that registration works when using sharded pusher workers."""
         http_client_mock1 = Mock(spec_set=["post_json_get_json"])
         http_client_mock1.post_json_get_json.side_effect = (