diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 584e7b8971..19f5322317 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import Optional, cast
from unittest.mock import Mock, call
from parameterized import parameterized
from signedjson.key import generate_signing_key
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -35,7 +37,9 @@ from synapse.handlers.presence import (
)
from synapse.rest import admin
from synapse.rest.client import room
-from synapse.types import UserID, get_domain_from_id
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.util import Clock
from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -44,10 +48,12 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
- def test_offline_to_online(self):
+ def test_offline_to_online(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -85,7 +91,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_online_to_online(self):
+ def test_online_to_online(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -128,7 +134,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_online_to_online_last_active_noop(self):
+ def test_online_to_online_last_active_noop(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -173,7 +179,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_online_to_online_last_active(self):
+ def test_online_to_online_last_active(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -210,7 +216,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_remote_ping_timer(self):
+ def test_remote_ping_timer(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -244,7 +250,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_online_to_offline(self):
+ def test_online_to_offline(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -266,7 +272,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
self.assertEqual(wheel_timer.insert.call_count, 0)
- def test_online_to_idle(self):
+ def test_online_to_idle(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -300,7 +306,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_persisting_presence_updates(self):
+ def test_persisting_presence_updates(self) -> None:
"""Tests that the latest presence state for each user is persisted correctly"""
# Create some test users and presence states for them
presence_states = []
@@ -322,7 +328,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.update_presence(presence_states))
# Check that each update is present in the database
- db_presence_states = self.get_success(
+ db_presence_states_raw = self.get_success(
self.store.get_all_presence_updates(
instance_name="master",
last_id=0,
@@ -332,7 +338,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
)
# Extract presence update user ID and state information into lists of tuples
- db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
+ db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states_raw[0]]
presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
# Compare what we put into the storage with what we got out.
@@ -343,7 +349,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
class PresenceTimeoutTestCase(unittest.TestCase):
"""Tests different timers and that the timer does not change `status_msg` of user."""
- def test_idle_timer(self):
+ def test_idle_timer(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -363,7 +369,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
self.assertEqual(new_state.status_msg, status_msg)
- def test_busy_no_idle(self):
+ def test_busy_no_idle(self) -> None:
"""
Tests that a user setting their presence to busy but idling doesn't turn their
presence state into unavailable.
@@ -387,7 +393,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.BUSY)
self.assertEqual(new_state.status_msg, status_msg)
- def test_sync_timeout(self):
+ def test_sync_timeout(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -407,7 +413,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)
- def test_sync_online(self):
+ def test_sync_online(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -429,7 +435,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.ONLINE)
self.assertEqual(new_state.status_msg, status_msg)
- def test_federation_ping(self):
+ def test_federation_ping(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -448,7 +454,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNotNone(new_state)
self.assertEqual(state, new_state)
- def test_no_timeout(self):
+ def test_no_timeout(self) -> None:
user_id = "@foo:bar"
now = 5000000
@@ -464,7 +470,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNone(new_state)
- def test_federation_timeout(self):
+ def test_federation_timeout(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -487,7 +493,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)
- def test_last_active(self):
+ def test_last_active(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -508,15 +514,15 @@ class PresenceTimeoutTestCase(unittest.TestCase):
class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
- def test_external_process_timeout(self):
+ def test_external_process_timeout(self) -> None:
"""Test that if an external process doesn't update the records for a while
we time out their syncing users presence.
"""
- process_id = 1
+ process_id = "1"
user_id = "@test:server"
# Notify handler that a user is now syncing.
@@ -544,7 +550,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
)
self.assertEqual(state.state, PresenceState.OFFLINE)
- def test_user_goes_offline_by_timeout_status_msg_remain(self):
+ def test_user_goes_offline_by_timeout_status_msg_remain(self) -> None:
"""Test that if a user doesn't update the records for a while
users presence goes `OFFLINE` because of timeout and `status_msg` remains.
"""
@@ -576,7 +582,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(state.state, PresenceState.OFFLINE)
self.assertEqual(state.status_msg, status_msg)
- def test_user_goes_offline_manually_with_no_status_msg(self):
+ def test_user_goes_offline_manually_with_no_status_msg(self) -> None:
"""Test that if a user change presence manually to `OFFLINE`
and no status is set, that `status_msg` is `None`.
"""
@@ -601,7 +607,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(state.state, PresenceState.OFFLINE)
self.assertEqual(state.status_msg, None)
- def test_user_goes_offline_manually_with_status_msg(self):
+ def test_user_goes_offline_manually_with_status_msg(self) -> None:
"""Test that if a user change presence manually to `OFFLINE`
and a status is set, that `status_msg` appears.
"""
@@ -618,7 +624,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
user_id, PresenceState.OFFLINE, "And now here."
)
- def test_user_reset_online_with_no_status(self):
+ def test_user_reset_online_with_no_status(self) -> None:
"""Test that if a user set again the presence manually
and no status is set, that `status_msg` is `None`.
"""
@@ -644,7 +650,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(state.state, PresenceState.ONLINE)
self.assertEqual(state.status_msg, None)
- def test_set_presence_with_status_msg_none(self):
+ def test_set_presence_with_status_msg_none(self) -> None:
"""Test that if a user set again the presence manually
and status is `None`, that `status_msg` is `None`.
"""
@@ -659,7 +665,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# Mark user as online and `status_msg = None`
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
- def test_set_presence_from_syncing_not_set(self):
+ def test_set_presence_from_syncing_not_set(self) -> None:
"""Test that presence is not set by syncing if affect_presence is false"""
user_id = "@test:server"
status_msg = "I'm here!"
@@ -680,7 +686,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# and status message should still be the same
self.assertEqual(state.status_msg, status_msg)
- def test_set_presence_from_syncing_is_set(self):
+ def test_set_presence_from_syncing_is_set(self) -> None:
"""Test that presence is set by syncing if affect_presence is true"""
user_id = "@test:server"
status_msg = "I'm here!"
@@ -699,7 +705,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# we should now be online
self.assertEqual(state.state, PresenceState.ONLINE)
- def test_set_presence_from_syncing_keeps_status(self):
+ def test_set_presence_from_syncing_keeps_status(self) -> None:
"""Test that presence set by syncing retains status message"""
user_id = "@test:server"
status_msg = "I'm here!"
@@ -726,7 +732,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
},
}
)
- def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool):
+ def test_set_presence_from_syncing_keeps_busy(
+ self, test_with_workers: bool
+ ) -> None:
"""Test that presence set by syncing doesn't affect busy status
Args:
@@ -767,7 +775,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
def _set_presencestate_with_status_msg(
self, user_id: str, state: str, status_msg: Optional[str]
- ):
+ ) -> None:
"""Set a PresenceState and status_msg and check the result.
Args:
@@ -790,14 +798,14 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.instance_name = hs.get_instance_name()
self.queue = self.presence_handler.get_federation_queue()
- def test_send_and_get(self):
+ def test_send_and_get(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -834,7 +842,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
self.assertFalse(limited)
self.assertCountEqual(rows, [])
- def test_send_and_get_split(self):
+ def test_send_and_get_split(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -877,7 +885,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
self.assertCountEqual(rows, expected_rows)
- def test_clear_queue_all(self):
+ def test_clear_queue_all(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -921,7 +929,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
self.assertCountEqual(rows, expected_rows)
- def test_partially_clear_queue(self):
+ def test_partially_clear_queue(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -982,7 +990,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
servlets = [room.register_servlets]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(
"server",
federation_http_client=None,
@@ -990,14 +998,14 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
)
return hs
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = super().default_config()
# Enable federation sending on the main process.
config["federation_sender_instances"] = None
return config
- def prepare(self, reactor, clock, hs):
- self.federation_sender = hs.get_federation_sender()
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.federation_sender = cast(Mock, hs.get_federation_sender())
self.event_builder_factory = hs.get_event_builder_factory()
self.federation_event_handler = hs.get_federation_event_handler()
self.presence_handler = hs.get_presence_handler()
@@ -1013,7 +1021,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# random key to use.
self.random_signing_key = generate_signing_key("ver")
- def test_remote_joins(self):
+ def test_remote_joins(self) -> None:
# We advance time to something that isn't 0, as we use 0 as a special
# value.
self.reactor.advance(1000000000000)
@@ -1061,7 +1069,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
destinations={"server3"}, states=[expected_state]
)
- def test_remote_gets_presence_when_local_user_joins(self):
+ def test_remote_gets_presence_when_local_user_joins(self) -> None:
# We advance time to something that isn't 0, as we use 0 as a special
# value.
self.reactor.advance(1000000000000)
@@ -1110,7 +1118,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
destinations={"server2", "server3"}, states=[expected_state]
)
- def _add_new_user(self, room_id, user_id):
+ def _add_new_user(self, room_id: str, user_id: str) -> None:
"""Add new user to the room by creating an event and poking the federation API."""
hostname = get_domain_from_id(user_id)
|