summary refs log tree commit diff
path: root/tests/handlers/test_presence.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/handlers/test_presence.py217
1 files changed, 192 insertions, 25 deletions
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py

index cc630d606c..6b7bf112c2 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py
@@ -23,14 +23,21 @@ from typing import Optional, cast from unittest.mock import Mock, call from parameterized import parameterized -from signedjson.key import generate_signing_key +from signedjson.key import ( + encode_verify_key_base64, + generate_signing_key, + get_verify_key, +) from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.presence import UserDevicePresenceState, UserPresenceState -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.events.builder import EventBuilder +from synapse.api.room_versions import ( + RoomVersion, +) +from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.events import EventBase, make_event_from_dict from synapse.federation.sender import FederationSender from synapse.handlers.presence import ( BUSY_ONLINE_TIMEOUT, @@ -45,18 +52,24 @@ from synapse.handlers.presence import ( handle_update, ) from synapse.rest import admin -from synapse.rest.client import room +from synapse.rest.client import login, room, sync from synapse.server import HomeServer from synapse.storage.database import LoggingDatabaseConnection +from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util import Clock from tests import unittest from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.unittest import override_config class PresenceUpdateTestCase(unittest.HomeserverTestCase): - servlets = [admin.register_servlets] + servlets = [ + admin.register_servlets, + login.register_servlets, + sync.register_servlets, + ] def prepare( self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer @@ -425,6 +438,102 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): wheel_timer.insert.assert_not_called() + # `rc_presence` is set very high during unit tests to avoid ratelimiting + # subtly impacting unrelated tests. We set the ratelimiting back to a + # reasonable value for the tests specific to presence ratelimiting. + @override_config( + {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}} + ) + def test_over_ratelimit_offline_to_online_to_unavailable(self) -> None: + """ + Send a presence update, check that it went through, immediately send another one and + check that it was ignored. + """ + self._test_ratelimit_offline_to_online_to_unavailable(ratelimited=True) + + @override_config( + {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}} + ) + def test_within_ratelimit_offline_to_online_to_unavailable(self) -> None: + """ + Send a presence update, check that it went through, advancing time a sufficient amount, + send another presence update and check that it also worked. + """ + self._test_ratelimit_offline_to_online_to_unavailable(ratelimited=False) + + @override_config( + {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}} + ) + def _test_ratelimit_offline_to_online_to_unavailable( + self, ratelimited: bool + ) -> None: + """Test rate limit for presence updates sent with sync requests. + + Args: + ratelimited: Test rate limited case. + """ + wheel_timer = Mock() + user_id = "@user:pass" + now = 5000000 + sync_url = "/sync?access_token=%s&set_presence=%s" + + # Register the user who syncs presence + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Get the handler (which kicks off a bunch of timers). + presence_handler = self.hs.get_presence_handler() + + # Ensure the user is initially offline. + prev_state = UserPresenceState.default(user_id) + new_state = prev_state.copy_and_replace( + state=PresenceState.OFFLINE, last_active_ts=now + ) + + state, persist_and_notify, federation_ping = handle_update( + prev_state, + new_state, + is_mine=True, + wheel_timer=wheel_timer, + now=now, + persist=False, + ) + + # Check that the user is offline. + state = self.get_success( + presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.OFFLINE) + + # Send sync request with set_presence=online. + channel = self.make_request("GET", sync_url % (access_token, "online")) + self.assertEqual(200, channel.code) + + # Assert the user is now online. + state = self.get_success( + presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.ONLINE) + + if not ratelimited: + # Advance time a sufficient amount to avoid rate limiting. + self.reactor.advance(30) + + # Send another sync request with set_presence=unavailable. + channel = self.make_request("GET", sync_url % (access_token, "unavailable")) + self.assertEqual(200, channel.code) + + state = self.get_success( + presence_handler.get_state(UserID.from_string(user_id)) + ) + + if ratelimited: + # Assert the user is still online and presence update was ignored. + self.assertEqual(state.state, PresenceState.ONLINE) + else: + # Assert the user is now unavailable. + self.assertEqual(state.state, PresenceState.UNAVAILABLE) + class PresenceTimeoutTestCase(unittest.TestCase): """Tests different timers and that the timer does not change `status_msg` of user.""" @@ -1107,7 +1216,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): ), ] ], - name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}", + name_func=lambda testcase_func, + param_num, + params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}", ) @unittest.override_config({"experimental_features": {"msc3026_enabled": True}}) def test_set_presence_from_syncing_multi_device( @@ -1343,7 +1454,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): ), ] ], - name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}", + name_func=lambda testcase_func, + param_num, + params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}", ) @unittest.override_config({"experimental_features": {"msc3026_enabled": True}}) def test_set_presence_from_non_syncing_multi_device( @@ -1821,6 +1934,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # self.event_builder_for_2.hostname = "test2" self.store = hs.get_datastores().main + self.storage_controllers = hs.get_storage_controllers() self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() @@ -1936,29 +2050,35 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): hostname = get_domain_from_id(user_id) - room_version = self.get_success(self.store.get_room_version_id(room_id)) + room_version = self.get_success(self.store.get_room_version(room_id)) - builder = EventBuilder( - state=self.state, - event_auth_handler=self._event_auth_handler, - store=self.store, - clock=self.clock, - hostname=hostname, - signing_key=self.random_signing_key, - room_version=KNOWN_ROOM_VERSIONS[room_version], - room_id=room_id, - type=EventTypes.Member, - sender=user_id, - state_key=user_id, - content={"membership": Membership.JOIN}, + state_map = self.get_success( + self.storage_controllers.state.get_current_state(room_id) ) - prev_event_ids = self.get_success( - self.store.get_latest_event_ids_in_room(room_id) + # Figure out what the forward extremities in the room are (the most recent + # events that aren't tied into the DAG) + forward_extremity_event_ids = self.get_success( + self.hs.get_datastores().main.get_latest_event_ids_in_room(room_id) ) - event = self.get_success( - builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None) + event = self.create_fake_event_from_remote_server( + remote_server_name=hostname, + event_dict={ + "room_id": room_id, + "sender": user_id, + "type": EventTypes.Member, + "state_key": user_id, + "depth": 1000, + "origin_server_ts": 1, + "content": {"membership": Membership.JOIN}, + "auth_events": [ + state_map[(EventTypes.Create, "")].event_id, + state_map[(EventTypes.JoinRules, "")].event_id, + ], + "prev_events": list(forward_extremity_event_ids), + }, + room_version=room_version, ) self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event)) @@ -1966,3 +2086,50 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # Check that it was successfully persisted. self.get_success(self.store.get_event(event.event_id)) self.get_success(self.store.get_event(event.event_id)) + + def create_fake_event_from_remote_server( + self, remote_server_name: str, event_dict: JsonDict, room_version: RoomVersion + ) -> EventBase: + """ + This is similar to what `FederatingHomeserverTestCase` is doing but we don't + need all of the extra baggage and we want to be able to create an event from + many remote servers. + """ + + # poke the other server's signing key into the key store, so that we don't + # make requests for it + other_server_signature_key = generate_signing_key("test") + verify_key = get_verify_key(other_server_signature_key) + verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) + + self.get_success( + self.hs.get_datastores().main.store_server_keys_response( + remote_server_name, + from_server=remote_server_name, + ts_added_ms=self.clock.time_msec(), + verify_keys={ + verify_key_id: FetchKeyResult( + verify_key=verify_key, + valid_until_ts=self.clock.time_msec() + 10000, + ), + }, + response_json={ + "verify_keys": { + verify_key_id: {"key": encode_verify_key_base64(verify_key)} + } + }, + ) + ) + + add_hashes_and_signatures( + room_version=room_version, + event_dict=event_dict, + signature_name=remote_server_name, + signing_key=other_server_signature_key, + ) + event = make_event_from_dict( + event_dict, + room_version=room_version, + ) + + return event