diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index cc630d606c..6b7bf112c2 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -23,14 +23,21 @@ from typing import Optional, cast
from unittest.mock import Mock, call
from parameterized import parameterized
-from signedjson.key import generate_signing_key
+from signedjson.key import (
+ encode_verify_key_base64,
+ generate_signing_key,
+ get_verify_key,
+)
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.presence import UserDevicePresenceState, UserPresenceState
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.events.builder import EventBuilder
+from synapse.api.room_versions import (
+ RoomVersion,
+)
+from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.events import EventBase, make_event_from_dict
from synapse.federation.sender import FederationSender
from synapse.handlers.presence import (
BUSY_ONLINE_TIMEOUT,
@@ -45,18 +52,24 @@ from synapse.handlers.presence import (
handle_update,
)
from synapse.rest import admin
-from synapse.rest.client import room
+from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
from synapse.storage.database import LoggingDatabaseConnection
+from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util import Clock
from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.unittest import override_config
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
- servlets = [admin.register_servlets]
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
@@ -425,6 +438,102 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
wheel_timer.insert.assert_not_called()
+ # `rc_presence` is set very high during unit tests to avoid ratelimiting
+ # subtly impacting unrelated tests. We set the ratelimiting back to a
+ # reasonable value for the tests specific to presence ratelimiting.
+ @override_config(
+ {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
+ )
+ def test_over_ratelimit_offline_to_online_to_unavailable(self) -> None:
+ """
+ Send a presence update, check that it went through, immediately send another one and
+ check that it was ignored.
+ """
+ self._test_ratelimit_offline_to_online_to_unavailable(ratelimited=True)
+
+ @override_config(
+ {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
+ )
+ def test_within_ratelimit_offline_to_online_to_unavailable(self) -> None:
+ """
+ Send a presence update, check that it went through, advancing time a sufficient amount,
+ send another presence update and check that it also worked.
+ """
+ self._test_ratelimit_offline_to_online_to_unavailable(ratelimited=False)
+
+ @override_config(
+ {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
+ )
+ def _test_ratelimit_offline_to_online_to_unavailable(
+ self, ratelimited: bool
+ ) -> None:
+ """Test rate limit for presence updates sent with sync requests.
+
+ Args:
+ ratelimited: Test rate limited case.
+ """
+ wheel_timer = Mock()
+ user_id = "@user:pass"
+ now = 5000000
+ sync_url = "/sync?access_token=%s&set_presence=%s"
+
+ # Register the user who syncs presence
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Get the handler (which kicks off a bunch of timers).
+ presence_handler = self.hs.get_presence_handler()
+
+ # Ensure the user is initially offline.
+ prev_state = UserPresenceState.default(user_id)
+ new_state = prev_state.copy_and_replace(
+ state=PresenceState.OFFLINE, last_active_ts=now
+ )
+
+ state, persist_and_notify, federation_ping = handle_update(
+ prev_state,
+ new_state,
+ is_mine=True,
+ wheel_timer=wheel_timer,
+ now=now,
+ persist=False,
+ )
+
+ # Check that the user is offline.
+ state = self.get_success(
+ presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.OFFLINE)
+
+ # Send sync request with set_presence=online.
+ channel = self.make_request("GET", sync_url % (access_token, "online"))
+ self.assertEqual(200, channel.code)
+
+ # Assert the user is now online.
+ state = self.get_success(
+ presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.ONLINE)
+
+ if not ratelimited:
+ # Advance time a sufficient amount to avoid rate limiting.
+ self.reactor.advance(30)
+
+ # Send another sync request with set_presence=unavailable.
+ channel = self.make_request("GET", sync_url % (access_token, "unavailable"))
+ self.assertEqual(200, channel.code)
+
+ state = self.get_success(
+ presence_handler.get_state(UserID.from_string(user_id))
+ )
+
+ if ratelimited:
+ # Assert the user is still online and presence update was ignored.
+ self.assertEqual(state.state, PresenceState.ONLINE)
+ else:
+ # Assert the user is now unavailable.
+ self.assertEqual(state.state, PresenceState.UNAVAILABLE)
+
class PresenceTimeoutTestCase(unittest.TestCase):
"""Tests different timers and that the timer does not change `status_msg` of user."""
@@ -1107,7 +1216,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
),
]
],
- name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}",
+ name_func=lambda testcase_func,
+ param_num,
+ params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}",
)
@unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
def test_set_presence_from_syncing_multi_device(
@@ -1343,7 +1454,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
),
]
],
- name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}",
+ name_func=lambda testcase_func,
+ param_num,
+ params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}",
)
@unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
def test_set_presence_from_non_syncing_multi_device(
@@ -1821,6 +1934,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# self.event_builder_for_2.hostname = "test2"
self.store = hs.get_datastores().main
+ self.storage_controllers = hs.get_storage_controllers()
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
@@ -1936,29 +2050,35 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
hostname = get_domain_from_id(user_id)
- room_version = self.get_success(self.store.get_room_version_id(room_id))
+ room_version = self.get_success(self.store.get_room_version(room_id))
- builder = EventBuilder(
- state=self.state,
- event_auth_handler=self._event_auth_handler,
- store=self.store,
- clock=self.clock,
- hostname=hostname,
- signing_key=self.random_signing_key,
- room_version=KNOWN_ROOM_VERSIONS[room_version],
- room_id=room_id,
- type=EventTypes.Member,
- sender=user_id,
- state_key=user_id,
- content={"membership": Membership.JOIN},
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id)
)
- prev_event_ids = self.get_success(
- self.store.get_latest_event_ids_in_room(room_id)
+ # Figure out what the forward extremities in the room are (the most recent
+ # events that aren't tied into the DAG)
+ forward_extremity_event_ids = self.get_success(
+ self.hs.get_datastores().main.get_latest_event_ids_in_room(room_id)
)
- event = self.get_success(
- builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
+ event = self.create_fake_event_from_remote_server(
+ remote_server_name=hostname,
+ event_dict={
+ "room_id": room_id,
+ "sender": user_id,
+ "type": EventTypes.Member,
+ "state_key": user_id,
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "content": {"membership": Membership.JOIN},
+ "auth_events": [
+ state_map[(EventTypes.Create, "")].event_id,
+ state_map[(EventTypes.JoinRules, "")].event_id,
+ ],
+ "prev_events": list(forward_extremity_event_ids),
+ },
+ room_version=room_version,
)
self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event))
@@ -1966,3 +2086,50 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# Check that it was successfully persisted.
self.get_success(self.store.get_event(event.event_id))
self.get_success(self.store.get_event(event.event_id))
+
+ def create_fake_event_from_remote_server(
+ self, remote_server_name: str, event_dict: JsonDict, room_version: RoomVersion
+ ) -> EventBase:
+ """
+ This is similar to what `FederatingHomeserverTestCase` is doing but we don't
+ need all of the extra baggage and we want to be able to create an event from
+ many remote servers.
+ """
+
+ # poke the other server's signing key into the key store, so that we don't
+ # make requests for it
+ other_server_signature_key = generate_signing_key("test")
+ verify_key = get_verify_key(other_server_signature_key)
+ verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+
+ self.get_success(
+ self.hs.get_datastores().main.store_server_keys_response(
+ remote_server_name,
+ from_server=remote_server_name,
+ ts_added_ms=self.clock.time_msec(),
+ verify_keys={
+ verify_key_id: FetchKeyResult(
+ verify_key=verify_key,
+ valid_until_ts=self.clock.time_msec() + 10000,
+ ),
+ },
+ response_json={
+ "verify_keys": {
+ verify_key_id: {"key": encode_verify_key_base64(verify_key)}
+ }
+ },
+ )
+ )
+
+ add_hashes_and_signatures(
+ room_version=room_version,
+ event_dict=event_dict,
+ signature_name=remote_server_name,
+ signing_key=other_server_signature_key,
+ )
+ event = make_event_from_dict(
+ event_dict,
+ room_version=room_version,
+ )
+
+ return event
|