summary refs log tree commit diff
path: root/synapse/handlers/delayed_events.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/delayed_events.py545
1 files changed, 545 insertions, 0 deletions
diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py
new file mode 100644

index 0000000000..cb2a34ff73 --- /dev/null +++ b/synapse/handlers/delayed_events.py
@@ -0,0 +1,545 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + +import logging +from typing import TYPE_CHECKING, List, Optional, Set, Tuple + +from twisted.internet.interfaces import IDelayedCall + +from synapse.api.constants import EventTypes +from synapse.api.errors import ShadowBanError +from synapse.api.ratelimiting import Ratelimiter +from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME +from synapse.logging.opentracing import set_tag +from synapse.metrics import event_processing_positions +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.http.delayed_events import ( + ReplicationAddedDelayedEventRestServlet, +) +from synapse.storage.databases.main.delayed_events import ( + DelayedEventDetails, + DelayID, + EventType, + StateKey, + Timestamp, + UserLocalpart, +) +from synapse.storage.databases.main.state_deltas import StateDelta +from synapse.types import ( + JsonDict, + Requester, + RoomID, + UserID, + create_requester, +) +from synapse.util.events import generate_fake_event_id +from synapse.util.metrics import Measure + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class DelayedEventsHandler: + def __init__(self, hs: "HomeServer"): + self._store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() + self._config = hs.config + self._clock = hs.get_clock() + self._event_creation_handler = hs.get_event_creation_handler() + self._room_member_handler = hs.get_room_member_handler() + + self._request_ratelimiter = hs.get_request_ratelimiter() + + # Ratelimiter for management of existing delayed events, + # keyed by the sending user ID & device ID. + self._delayed_event_mgmt_ratelimiter = Ratelimiter( + store=self._store, + clock=self._clock, + cfg=self._config.ratelimiting.rc_delayed_event_mgmt, + ) + + self._next_delayed_event_call: Optional[IDelayedCall] = None + + # The current position in the current_state_delta stream + self._event_pos: Optional[int] = None + + # Guard to ensure we only process event deltas one at a time + self._event_processing = False + + if hs.config.worker.worker_app is None: + self._repl_client = None + + async def _schedule_db_events() -> None: + # We kick this off to pick up outstanding work from before the last restart. + # Block until we're up to date. + await self._unsafe_process_new_event() + hs.get_notifier().add_replication_callback(self.notify_new_event) + # Kick off again (without blocking) to catch any missed notifications + # that may have fired before the callback was added. + self._clock.call_later(0, self.notify_new_event) + + # Delayed events that are already marked as processed on startup might not have been + # sent properly on the last run of the server, so unmark them to send them again. + # Caveat: this will double-send delayed events that successfully persisted, but failed + # to be removed from the DB table of delayed events. + # TODO: To avoid double-sending, scan the timeline to find which of these events were + # already sent. To do so, must store delay_ids in sent events to retrieve them later. + await self._store.unprocess_delayed_events() + + events, next_send_ts = await self._store.process_timeout_delayed_events( + self._get_current_ts() + ) + + if next_send_ts: + self._schedule_next_at(next_send_ts) + + # Can send the events in background after having awaited on marking them as processed + run_as_background_process( + "_send_events", + self._send_events, + events, + ) + + self._initialized_from_db = run_as_background_process( + "_schedule_db_events", _schedule_db_events + ) + else: + self._repl_client = ReplicationAddedDelayedEventRestServlet.make_client(hs) + + @property + def _is_master(self) -> bool: + return self._repl_client is None + + def notify_new_event(self) -> None: + """ + Called when there may be more state event deltas to process, + which should cancel pending delayed events for the same state. + """ + if self._event_processing: + return + + self._event_processing = True + + async def process() -> None: + try: + await self._unsafe_process_new_event() + finally: + self._event_processing = False + + run_as_background_process("delayed_events.notify_new_event", process) + + async def _unsafe_process_new_event(self) -> None: + # If self._event_pos is None then means we haven't fetched it from the DB yet + if self._event_pos is None: + self._event_pos = await self._store.get_delayed_events_stream_pos() + room_max_stream_ordering = self._store.get_room_max_stream_ordering() + if self._event_pos > room_max_stream_ordering: + # apparently, we've processed more events than exist in the database! + # this can happen if events are removed with history purge or similar. + logger.warning( + "Event stream ordering appears to have gone backwards (%i -> %i): " + "rewinding delayed events processor", + self._event_pos, + room_max_stream_ordering, + ) + self._event_pos = room_max_stream_ordering + + # Loop round handling deltas until we're up to date + while True: + with Measure(self._clock, "delayed_events_delta"): + room_max_stream_ordering = self._store.get_room_max_stream_ordering() + if self._event_pos == room_max_stream_ordering: + return + + logger.debug( + "Processing delayed events %s->%s", + self._event_pos, + room_max_stream_ordering, + ) + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( + self._event_pos, room_max_stream_ordering + ) + + logger.debug( + "Handling %d state deltas for delayed events processing", + len(deltas), + ) + await self._handle_state_deltas(deltas) + + self._event_pos = max_pos + + # Expose current event processing position to prometheus + event_processing_positions.labels("delayed_events").set(max_pos) + + await self._store.update_delayed_events_stream_pos(max_pos) + + async def _handle_state_deltas(self, deltas: List[StateDelta]) -> None: + """ + Process current state deltas to cancel other users' pending delayed events + that target the same state. + """ + for delta in deltas: + if delta.event_id is None: + logger.debug( + "Not handling delta for deleted state: %r %r", + delta.event_type, + delta.state_key, + ) + continue + + logger.debug( + "Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id + ) + + event = await self._store.get_event( + delta.event_id, check_room_id=delta.room_id, allow_rejected=True, allow_none=True + ) + + if event is None or event.rejected_reason is not None: + # This event has been rejected, so we don't want to cancel any delayed events for it. + continue + + sender = UserID.from_string(event.sender) + + next_send_ts = await self._store.cancel_delayed_state_events( + room_id=delta.room_id, + event_type=delta.event_type, + state_key=delta.state_key, + not_from_localpart=( + sender.localpart + if sender.domain == self._config.server.server_name + else "" + ), + ) + + if self._next_send_ts_changed(next_send_ts): + self._schedule_next_at_or_none(next_send_ts) + + async def add( + self, + requester: Requester, + *, + room_id: str, + event_type: str, + state_key: Optional[str], + origin_server_ts: Optional[int], + content: JsonDict, + delay: int, + ) -> str: + """ + Creates a new delayed event and schedules its delivery. + + Args: + requester: The requester of the delayed event, who will be its owner. + room_id: The ID of the room where the event should be sent to. + event_type: The type of event to be sent. + state_key: The state key of the event to be sent, or None if it is not a state event. + origin_server_ts: The custom timestamp to send the event with. + If None, the timestamp will be the actual time when the event is sent. + content: The content of the event to be sent. + delay: How long (in milliseconds) to wait before automatically sending the event. + + Returns: The ID of the added delayed event. + + Raises: + SynapseError: if the delayed event fails validation checks. + """ + # Use standard request limiter for scheduling new delayed events. + # TODO: Instead apply ratelimiting based on the scheduled send time. + # See https://github.com/element-hq/synapse/issues/18021 + await self._request_ratelimiter.ratelimit(requester) + + self._event_creation_handler.validator.validate_builder( + self._event_creation_handler.event_builder_factory.for_room_version( + await self._store.get_room_version(room_id), + { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": str(requester.user), + **({"state_key": state_key} if state_key is not None else {}), + }, + ) + ) + + creation_ts = self._get_current_ts() + + delay_id, next_send_ts = await self._store.add_delayed_event( + user_localpart=requester.user.localpart, + device_id=requester.device_id, + creation_ts=creation_ts, + room_id=room_id, + event_type=event_type, + state_key=state_key, + origin_server_ts=origin_server_ts, + content=content, + delay=delay, + ) + + if self._repl_client is not None: + # NOTE: If this throws, the delayed event will remain in the DB and + # will be picked up once the main worker gets another delayed event. + await self._repl_client( + instance_name=MAIN_PROCESS_INSTANCE_NAME, + next_send_ts=next_send_ts, + ) + elif self._next_send_ts_changed(next_send_ts): + self._schedule_next_at(next_send_ts) + + return delay_id + + def on_added(self, next_send_ts: int) -> None: + next_send_ts = Timestamp(next_send_ts) + if self._next_send_ts_changed(next_send_ts): + self._schedule_next_at(next_send_ts) + + async def cancel(self, requester: Requester, delay_id: str) -> None: + """ + Cancels the scheduled delivery of the matching delayed event. + + Args: + requester: The owner of the delayed event to act on. + delay_id: The ID of the delayed event to act on. + + Raises: + NotFoundError: if no matching delayed event could be found. + """ + assert self._is_master + await self._delayed_event_mgmt_ratelimiter.ratelimit( + requester, + (requester.user.to_string(), requester.device_id), + ) + await self._initialized_from_db + + next_send_ts = await self._store.cancel_delayed_event( + delay_id=delay_id, + user_localpart=requester.user.localpart, + ) + + if self._next_send_ts_changed(next_send_ts): + self._schedule_next_at_or_none(next_send_ts) + + async def restart(self, requester: Requester, delay_id: str) -> None: + """ + Restarts the scheduled delivery of the matching delayed event. + + Args: + requester: The owner of the delayed event to act on. + delay_id: The ID of the delayed event to act on. + + Raises: + NotFoundError: if no matching delayed event could be found. + """ + assert self._is_master + await self._delayed_event_mgmt_ratelimiter.ratelimit( + requester, + (requester.user.to_string(), requester.device_id), + ) + await self._initialized_from_db + + next_send_ts = await self._store.restart_delayed_event( + delay_id=delay_id, + user_localpart=requester.user.localpart, + current_ts=self._get_current_ts(), + ) + + if self._next_send_ts_changed(next_send_ts): + self._schedule_next_at(next_send_ts) + + async def send(self, requester: Requester, delay_id: str) -> None: + """ + Immediately sends the matching delayed event, instead of waiting for its scheduled delivery. + + Args: + requester: The owner of the delayed event to act on. + delay_id: The ID of the delayed event to act on. + + Raises: + NotFoundError: if no matching delayed event could be found. + """ + assert self._is_master + # Use standard request limiter for sending delayed events on-demand, + # as an on-demand send is similar to sending a regular event. + await self._request_ratelimiter.ratelimit(requester) + await self._initialized_from_db + + event, next_send_ts = await self._store.process_target_delayed_event( + delay_id=delay_id, + user_localpart=requester.user.localpart, + ) + + if self._next_send_ts_changed(next_send_ts): + self._schedule_next_at_or_none(next_send_ts) + + await self._send_event( + DelayedEventDetails( + delay_id=DelayID(delay_id), + user_localpart=UserLocalpart(requester.user.localpart), + room_id=event.room_id, + type=event.type, + state_key=event.state_key, + origin_server_ts=event.origin_server_ts, + content=event.content, + device_id=event.device_id, + ) + ) + + async def _send_on_timeout(self) -> None: + self._next_delayed_event_call = None + + events, next_send_ts = await self._store.process_timeout_delayed_events( + self._get_current_ts() + ) + + if next_send_ts: + self._schedule_next_at(next_send_ts) + + await self._send_events(events) + + async def _send_events(self, events: List[DelayedEventDetails]) -> None: + sent_state: Set[Tuple[RoomID, EventType, StateKey]] = set() + for event in events: + if event.state_key is not None: + state_info = (event.room_id, event.type, event.state_key) + if state_info in sent_state: + continue + else: + state_info = None + try: + # TODO: send in background if message event or non-conflicting state event + await self._send_event(event) + if state_info is not None: + sent_state.add(state_info) + except Exception: + logger.exception("Failed to send delayed event") + + for room_id, event_type, state_key in sent_state: + await self._store.delete_processed_delayed_state_events( + room_id=str(room_id), + event_type=event_type, + state_key=state_key, + ) + + def _schedule_next_at_or_none(self, next_send_ts: Optional[Timestamp]) -> None: + if next_send_ts is not None: + self._schedule_next_at(next_send_ts) + elif self._next_delayed_event_call is not None: + self._next_delayed_event_call.cancel() + self._next_delayed_event_call = None + + def _schedule_next_at(self, next_send_ts: Timestamp) -> None: + delay = next_send_ts - self._get_current_ts() + delay_sec = delay / 1000 if delay > 0 else 0 + + if self._next_delayed_event_call is None: + self._next_delayed_event_call = self._clock.call_later( + delay_sec, + run_as_background_process, + "_send_on_timeout", + self._send_on_timeout, + ) + else: + self._next_delayed_event_call.reset(delay_sec) + + async def get_all_for_user(self, requester: Requester) -> List[JsonDict]: + """Return all pending delayed events requested by the given user.""" + await self._delayed_event_mgmt_ratelimiter.ratelimit( + requester, + (requester.user.to_string(), requester.device_id), + ) + return await self._store.get_all_delayed_events_for_user( + requester.user.localpart + ) + + async def _send_event( + self, + event: DelayedEventDetails, + txn_id: Optional[str] = None, + ) -> None: + user_id = UserID(event.user_localpart, self._config.server.server_name) + user_id_str = user_id.to_string() + # Create a new requester from what data is currently available + requester = create_requester( + user_id, + is_guest=await self._store.is_guest(user_id_str), + device_id=event.device_id, + ) + + try: + if event.state_key is not None and event.type == EventTypes.Member: + membership = event.content.get("membership") + assert membership is not None + event_id, _ = await self._room_member_handler.update_membership( + requester, + target=UserID.from_string(event.state_key), + room_id=event.room_id.to_string(), + action=membership, + content=event.content, + origin_server_ts=event.origin_server_ts, + ) + else: + event_dict: JsonDict = { + "type": event.type, + "content": event.content, + "room_id": event.room_id.to_string(), + "sender": user_id_str, + } + + if event.origin_server_ts is not None: + event_dict["origin_server_ts"] = event.origin_server_ts + + if event.state_key is not None: + event_dict["state_key"] = event.state_key + + ( + sent_event, + _, + ) = await self._event_creation_handler.create_and_send_nonmember_event( + requester, + event_dict, + txn_id=txn_id, + ) + event_id = sent_event.event_id + except ShadowBanError: + event_id = generate_fake_event_id() + finally: + # TODO: If this is a temporary error, retry. Otherwise, consider notifying clients of the failure + try: + await self._store.delete_processed_delayed_event( + event.delay_id, event.user_localpart + ) + except Exception: + logger.exception("Failed to delete processed delayed event") + + set_tag("event_id", event_id) + + def _get_current_ts(self) -> Timestamp: + return Timestamp(self._clock.time_msec()) + + def _next_send_ts_changed(self, next_send_ts: Optional[Timestamp]) -> bool: + # The DB alone knows if the next send time changed after adding/modifying + # a delayed event, but if we were to ever miss updating our delayed call's + # firing time, we may miss other updates. So, keep track of changes to the + # the next send time here instead of in the DB. + cached_next_send_ts = ( + int(self._next_delayed_event_call.getTime() * 1000) + if self._next_delayed_event_call is not None + else None + ) + return next_send_ts != cached_next_send_ts