summary refs log tree commit diff
path: root/tests/handlers/test_presence.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_presence.py')
-rw-r--r--tests/handlers/test_presence.py100
1 files changed, 54 insertions, 46 deletions
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 584e7b8971..19f5322317 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Optional, cast
 from unittest.mock import Mock, call
 
 from parameterized import parameterized
 from signedjson.key import generate_signing_key
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EventTypes, Membership, PresenceState
 from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -35,7 +37,9 @@ from synapse.handlers.presence import (
 )
 from synapse.rest import admin
 from synapse.rest.client import room
-from synapse.types import UserID, get_domain_from_id
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.util import Clock
 
 from tests import unittest
 from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -44,10 +48,12 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase
 class PresenceUpdateTestCase(unittest.HomeserverTestCase):
     servlets = [admin.register_servlets]
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
         self.store = homeserver.get_datastores().main
 
-    def test_offline_to_online(self):
+    def test_offline_to_online(self) -> None:
         wheel_timer = Mock()
         user_id = "@foo:bar"
         now = 5000000
@@ -85,7 +91,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
             any_order=True,
         )
 
-    def test_online_to_online(self):
+    def test_online_to_online(self) -> None:
         wheel_timer = Mock()
         user_id = "@foo:bar"
         now = 5000000
@@ -128,7 +134,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
             any_order=True,
         )
 
-    def test_online_to_online_last_active_noop(self):
+    def test_online_to_online_last_active_noop(self) -> None:
         wheel_timer = Mock()
         user_id = "@foo:bar"
         now = 5000000
@@ -173,7 +179,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
             any_order=True,
         )
 
-    def test_online_to_online_last_active(self):
+    def test_online_to_online_last_active(self) -> None:
         wheel_timer = Mock()
         user_id = "@foo:bar"
         now = 5000000
@@ -210,7 +216,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
             any_order=True,
         )
 
-    def test_remote_ping_timer(self):
+    def test_remote_ping_timer(self) -> None:
         wheel_timer = Mock()
         user_id = "@foo:bar"
         now = 5000000
@@ -244,7 +250,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
             any_order=True,
         )
 
-    def test_online_to_offline(self):
+    def test_online_to_offline(self) -> None:
         wheel_timer = Mock()
         user_id = "@foo:bar"
         now = 5000000
@@ -266,7 +272,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(wheel_timer.insert.call_count, 0)
 
-    def test_online_to_idle(self):
+    def test_online_to_idle(self) -> None:
         wheel_timer = Mock()
         user_id = "@foo:bar"
         now = 5000000
@@ -300,7 +306,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
             any_order=True,
         )
 
-    def test_persisting_presence_updates(self):
+    def test_persisting_presence_updates(self) -> None:
         """Tests that the latest presence state for each user is persisted correctly"""
         # Create some test users and presence states for them
         presence_states = []
@@ -322,7 +328,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
         self.get_success(self.store.update_presence(presence_states))
 
         # Check that each update is present in the database
-        db_presence_states = self.get_success(
+        db_presence_states_raw = self.get_success(
             self.store.get_all_presence_updates(
                 instance_name="master",
                 last_id=0,
@@ -332,7 +338,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
         )
 
         # Extract presence update user ID and state information into lists of tuples
-        db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
+        db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states_raw[0]]
         presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
 
         # Compare what we put into the storage with what we got out.
@@ -343,7 +349,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
 class PresenceTimeoutTestCase(unittest.TestCase):
     """Tests different timers and that the timer does not change `status_msg` of user."""
 
-    def test_idle_timer(self):
+    def test_idle_timer(self) -> None:
         user_id = "@foo:bar"
         status_msg = "I'm here!"
         now = 5000000
@@ -363,7 +369,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
         self.assertEqual(new_state.status_msg, status_msg)
 
-    def test_busy_no_idle(self):
+    def test_busy_no_idle(self) -> None:
         """
         Tests that a user setting their presence to busy but idling doesn't turn their
         presence state into unavailable.
@@ -387,7 +393,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         self.assertEqual(new_state.state, PresenceState.BUSY)
         self.assertEqual(new_state.status_msg, status_msg)
 
-    def test_sync_timeout(self):
+    def test_sync_timeout(self) -> None:
         user_id = "@foo:bar"
         status_msg = "I'm here!"
         now = 5000000
@@ -407,7 +413,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         self.assertEqual(new_state.state, PresenceState.OFFLINE)
         self.assertEqual(new_state.status_msg, status_msg)
 
-    def test_sync_online(self):
+    def test_sync_online(self) -> None:
         user_id = "@foo:bar"
         status_msg = "I'm here!"
         now = 5000000
@@ -429,7 +435,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         self.assertEqual(new_state.state, PresenceState.ONLINE)
         self.assertEqual(new_state.status_msg, status_msg)
 
-    def test_federation_ping(self):
+    def test_federation_ping(self) -> None:
         user_id = "@foo:bar"
         status_msg = "I'm here!"
         now = 5000000
@@ -448,7 +454,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         self.assertIsNotNone(new_state)
         self.assertEqual(state, new_state)
 
-    def test_no_timeout(self):
+    def test_no_timeout(self) -> None:
         user_id = "@foo:bar"
         now = 5000000
 
@@ -464,7 +470,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
 
         self.assertIsNone(new_state)
 
-    def test_federation_timeout(self):
+    def test_federation_timeout(self) -> None:
         user_id = "@foo:bar"
         status_msg = "I'm here!"
         now = 5000000
@@ -487,7 +493,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         self.assertEqual(new_state.state, PresenceState.OFFLINE)
         self.assertEqual(new_state.status_msg, status_msg)
 
-    def test_last_active(self):
+    def test_last_active(self) -> None:
         user_id = "@foo:bar"
         status_msg = "I'm here!"
         now = 5000000
@@ -508,15 +514,15 @@ class PresenceTimeoutTestCase(unittest.TestCase):
 
 
 class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.presence_handler = hs.get_presence_handler()
         self.clock = hs.get_clock()
 
-    def test_external_process_timeout(self):
+    def test_external_process_timeout(self) -> None:
         """Test that if an external process doesn't update the records for a while
         we time out their syncing users presence.
         """
-        process_id = 1
+        process_id = "1"
         user_id = "@test:server"
 
         # Notify handler that a user is now syncing.
@@ -544,7 +550,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         )
         self.assertEqual(state.state, PresenceState.OFFLINE)
 
-    def test_user_goes_offline_by_timeout_status_msg_remain(self):
+    def test_user_goes_offline_by_timeout_status_msg_remain(self) -> None:
         """Test that if a user doesn't update the records for a while
         users presence goes `OFFLINE` because of timeout and `status_msg` remains.
         """
@@ -576,7 +582,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(state.state, PresenceState.OFFLINE)
         self.assertEqual(state.status_msg, status_msg)
 
-    def test_user_goes_offline_manually_with_no_status_msg(self):
+    def test_user_goes_offline_manually_with_no_status_msg(self) -> None:
         """Test that if a user change presence manually to `OFFLINE`
         and no status is set, that `status_msg` is `None`.
         """
@@ -601,7 +607,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(state.state, PresenceState.OFFLINE)
         self.assertEqual(state.status_msg, None)
 
-    def test_user_goes_offline_manually_with_status_msg(self):
+    def test_user_goes_offline_manually_with_status_msg(self) -> None:
         """Test that if a user change presence manually to `OFFLINE`
         and a status is set, that `status_msg` appears.
         """
@@ -618,7 +624,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
             user_id, PresenceState.OFFLINE, "And now here."
         )
 
-    def test_user_reset_online_with_no_status(self):
+    def test_user_reset_online_with_no_status(self) -> None:
         """Test that if a user set again the presence manually
         and no status is set, that `status_msg` is `None`.
         """
@@ -644,7 +650,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(state.state, PresenceState.ONLINE)
         self.assertEqual(state.status_msg, None)
 
-    def test_set_presence_with_status_msg_none(self):
+    def test_set_presence_with_status_msg_none(self) -> None:
         """Test that if a user set again the presence manually
         and status is `None`, that `status_msg` is `None`.
         """
@@ -659,7 +665,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         # Mark user as online and `status_msg = None`
         self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
 
-    def test_set_presence_from_syncing_not_set(self):
+    def test_set_presence_from_syncing_not_set(self) -> None:
         """Test that presence is not set by syncing if affect_presence is false"""
         user_id = "@test:server"
         status_msg = "I'm here!"
@@ -680,7 +686,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         # and status message should still be the same
         self.assertEqual(state.status_msg, status_msg)
 
-    def test_set_presence_from_syncing_is_set(self):
+    def test_set_presence_from_syncing_is_set(self) -> None:
         """Test that presence is set by syncing if affect_presence is true"""
         user_id = "@test:server"
         status_msg = "I'm here!"
@@ -699,7 +705,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         # we should now be online
         self.assertEqual(state.state, PresenceState.ONLINE)
 
-    def test_set_presence_from_syncing_keeps_status(self):
+    def test_set_presence_from_syncing_keeps_status(self) -> None:
         """Test that presence set by syncing retains status message"""
         user_id = "@test:server"
         status_msg = "I'm here!"
@@ -726,7 +732,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
             },
         }
     )
-    def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool):
+    def test_set_presence_from_syncing_keeps_busy(
+        self, test_with_workers: bool
+    ) -> None:
         """Test that presence set by syncing doesn't affect busy status
 
         Args:
@@ -767,7 +775,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
 
     def _set_presencestate_with_status_msg(
         self, user_id: str, state: str, status_msg: Optional[str]
-    ):
+    ) -> None:
         """Set a PresenceState and status_msg and check the result.
 
         Args:
@@ -790,14 +798,14 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
 
 
 class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.presence_handler = hs.get_presence_handler()
         self.clock = hs.get_clock()
         self.instance_name = hs.get_instance_name()
 
         self.queue = self.presence_handler.get_federation_queue()
 
-    def test_send_and_get(self):
+    def test_send_and_get(self) -> None:
         state1 = UserPresenceState.default("@user1:test")
         state2 = UserPresenceState.default("@user2:test")
         state3 = UserPresenceState.default("@user3:test")
@@ -834,7 +842,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
         self.assertFalse(limited)
         self.assertCountEqual(rows, [])
 
-    def test_send_and_get_split(self):
+    def test_send_and_get_split(self) -> None:
         state1 = UserPresenceState.default("@user1:test")
         state2 = UserPresenceState.default("@user2:test")
         state3 = UserPresenceState.default("@user3:test")
@@ -877,7 +885,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
 
         self.assertCountEqual(rows, expected_rows)
 
-    def test_clear_queue_all(self):
+    def test_clear_queue_all(self) -> None:
         state1 = UserPresenceState.default("@user1:test")
         state2 = UserPresenceState.default("@user2:test")
         state3 = UserPresenceState.default("@user3:test")
@@ -921,7 +929,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
 
         self.assertCountEqual(rows, expected_rows)
 
-    def test_partially_clear_queue(self):
+    def test_partially_clear_queue(self) -> None:
         state1 = UserPresenceState.default("@user1:test")
         state2 = UserPresenceState.default("@user2:test")
         state3 = UserPresenceState.default("@user3:test")
@@ -982,7 +990,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
 
     servlets = [room.register_servlets]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver(
             "server",
             federation_http_client=None,
@@ -990,14 +998,14 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         )
         return hs
 
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         config = super().default_config()
         # Enable federation sending on the main process.
         config["federation_sender_instances"] = None
         return config
 
-    def prepare(self, reactor, clock, hs):
-        self.federation_sender = hs.get_federation_sender()
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.federation_sender = cast(Mock, hs.get_federation_sender())
         self.event_builder_factory = hs.get_event_builder_factory()
         self.federation_event_handler = hs.get_federation_event_handler()
         self.presence_handler = hs.get_presence_handler()
@@ -1013,7 +1021,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         # random key to use.
         self.random_signing_key = generate_signing_key("ver")
 
-    def test_remote_joins(self):
+    def test_remote_joins(self) -> None:
         # We advance time to something that isn't 0, as we use 0 as a special
         # value.
         self.reactor.advance(1000000000000)
@@ -1061,7 +1069,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
             destinations={"server3"}, states=[expected_state]
         )
 
-    def test_remote_gets_presence_when_local_user_joins(self):
+    def test_remote_gets_presence_when_local_user_joins(self) -> None:
         # We advance time to something that isn't 0, as we use 0 as a special
         # value.
         self.reactor.advance(1000000000000)
@@ -1110,7 +1118,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
             destinations={"server2", "server3"}, states=[expected_state]
         )
 
-    def _add_new_user(self, room_id, user_id):
+    def _add_new_user(self, room_id: str, user_id: str) -> None:
         """Add new user to the room by creating an event and poking the federation API."""
 
         hostname = get_domain_from_id(user_id)