summary refs log tree commit diff
path: root/synapse/storage/databases/state/deletion.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/state/deletion.py')
-rw-r--r--synapse/storage/databases/state/deletion.py446
1 files changed, 446 insertions, 0 deletions
diff --git a/synapse/storage/databases/state/deletion.py b/synapse/storage/databases/state/deletion.py
new file mode 100644

index 0000000000..07dbbc8e75 --- /dev/null +++ b/synapse/storage/databases/state/deletion.py
@@ -0,0 +1,446 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + + +import contextlib +from typing import ( + TYPE_CHECKING, + AbstractSet, + AsyncIterator, + Collection, + Mapping, + Set, + Tuple, +) + +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) +from synapse.storage.engines import PostgresEngine +from synapse.util.stringutils import shortstr + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class StateDeletionDataStore: + """Manages deletion of state groups in a safe manner. + + Deleting state groups is challenging as before we actually delete them we + need to ensure that there are no in-flight events that refer to the state + groups that we want to delete. + + To handle this, we take two approaches. First, before we persist any event + we ensure that the state group still exists and mark in the + `state_groups_persisting` table that the state group is about to be used. + (Note that we have to have the extra table here as state groups and events + can be in different databases, and thus we can't check for the existence of + state groups in the persist event transaction). Once the event has been + persisted, we can remove the row from `state_groups_persisting`. So long as + we check that table before deleting state groups, we can ensure that we + never persist events that reference deleted state groups, maintaining + database integrity. + + However, we want to avoid throwing exceptions so deep in the process of + persisting events. So instead of deleting state groups immediately, we mark + them as pending/proposed for deletion and wait for a certain amount of time + before performing the deletion. When we come to handle new events that + reference state groups, we check if they are pending deletion and bump the + time for when they'll be deleted (to give a chance for the event to be + persisted, or not). + + When deleting, we need to check that state groups remain unreferenced. There + is a race here where we a) fetch state groups that are ready for deletion, + b) check they're unreferenced, c) the state group becomes referenced but + then gets marked as pending deletion again, d) during the deletion + transaction we recheck `state_groups_pending_deletion` table again and see + that it exists and so continue with the deletion. To prevent this from + happening we add a `sequence_number` column to + `state_groups_pending_deletion`, and during deletion we ensure that for a + state group we're about to delete that the sequence number doesn't change + between steps (a) and (d). So long as we always bump the sequence number + whenever an event may become used the race can never happen. + """ + + # How long to wait before we delete state groups. This should be long enough + # for any in-flight events to be persisted. If events take longer to persist + # and any of the state groups they reference have been deleted, then the + # event will fail to persist (as well as any event in the same batch). + DELAY_BEFORE_DELETION_MS = 10 * 60 * 1000 + + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + self._clock = hs.get_clock() + self.db_pool = database + self._instance_name = hs.get_instance_name() + + # TODO: Clear from `state_groups_persisting` any holdovers from previous + # running instance. + + async def check_state_groups_and_bump_deletion( + self, state_groups: AbstractSet[int] + ) -> Collection[int]: + """Checks to make sure that the state groups haven't been deleted, and + if they're pending deletion we delay it (allowing time for any event + that will use them to finish persisting). + + Returns: + The state groups that are missing, if any. + """ + + return await self.db_pool.runInteraction( + "check_state_groups_and_bump_deletion", + self._check_state_groups_and_bump_deletion_txn, + state_groups, + ) + + def _check_state_groups_and_bump_deletion_txn( + self, txn: LoggingTransaction, state_groups: AbstractSet[int] + ) -> Collection[int]: + existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups) + + self._bump_deletion_txn(txn, existing_state_groups) + + missing_state_groups = state_groups - existing_state_groups + if missing_state_groups: + return missing_state_groups + + return () + + def _bump_deletion_txn( + self, txn: LoggingTransaction, state_groups: Collection[int] + ) -> None: + """Update any pending deletions of the state group that they may now be + referenced.""" + + if not state_groups: + return + + now = self._clock.time_msec() + if isinstance(self.db_pool.engine, PostgresEngine): + clause, args = make_in_list_sql_clause( + self.db_pool.engine, "state_group", state_groups + ) + sql = f""" + UPDATE state_groups_pending_deletion + SET sequence_number = DEFAULT, insertion_ts = ? + WHERE {clause} + """ + args.insert(0, now) + txn.execute(sql, args) + else: + rows = self.db_pool.simple_select_many_txn( + txn, + table="state_groups_pending_deletion", + column="state_group", + iterable=state_groups, + keyvalues={}, + retcols=("state_group",), + ) + if not rows: + return + + state_groups_to_update = [state_group for (state_group,) in rows] + + self.db_pool.simple_delete_many_txn( + txn, + table="state_groups_pending_deletion", + column="state_group", + values=state_groups_to_update, + keyvalues={}, + ) + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_pending_deletion", + keys=("state_group", "insertion_ts"), + values=[(state_group, now) for state_group in state_groups_to_update], + ) + + def _get_existing_groups_with_lock( + self, txn: LoggingTransaction, state_groups: Collection[int] + ) -> AbstractSet[int]: + """Return which of the given state groups are in the database, and locks + those rows with `KEY SHARE` to ensure they don't get concurrently + deleted.""" + clause, args = make_in_list_sql_clause(self.db_pool.engine, "id", state_groups) + + sql = f""" + SELECT id FROM state_groups + WHERE {clause} + """ + if isinstance(self.db_pool.engine, PostgresEngine): + # On postgres we add a row level lock to the rows to ensure that we + # conflict with any concurrent DELETEs. `FOR KEY SHARE` lock will + # not conflict with other read + sql += """ + FOR KEY SHARE + """ + + txn.execute(sql, args) + return {state_group for (state_group,) in txn} + + @contextlib.asynccontextmanager + async def persisting_state_group_references( + self, event_and_contexts: Collection[Tuple[EventBase, EventContext]] + ) -> AsyncIterator[None]: + """Wraps the persistence of the given events and contexts, ensuring that + any state groups referenced still exist and that they don't get deleted + during this.""" + + referenced_state_groups: Set[int] = set() + for event, ctx in event_and_contexts: + if ctx.rejected or event.internal_metadata.is_outlier(): + continue + + assert ctx.state_group is not None + + referenced_state_groups.add(ctx.state_group) + + if ctx.state_group_before_event: + referenced_state_groups.add(ctx.state_group_before_event) + + if not referenced_state_groups: + # We don't reference any state groups, so nothing to do + yield + return + + await self.db_pool.runInteraction( + "mark_state_groups_as_persisting", + self._mark_state_groups_as_persisting_txn, + referenced_state_groups, + ) + + error = True + try: + yield None + error = False + finally: + await self.db_pool.runInteraction( + "finish_persisting", + self._finish_persisting_txn, + referenced_state_groups, + error=error, + ) + + def _mark_state_groups_as_persisting_txn( + self, txn: LoggingTransaction, state_groups: Set[int] + ) -> None: + """Marks the given state groups as being persisted.""" + + existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups) + missing_state_groups = state_groups - existing_state_groups + if missing_state_groups: + raise Exception( + f"state groups have been deleted: {shortstr(missing_state_groups)}" + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_persisting", + keys=("state_group", "instance_name"), + values=[(state_group, self._instance_name) for state_group in state_groups], + ) + + def _finish_persisting_txn( + self, txn: LoggingTransaction, state_groups: Collection[int], error: bool + ) -> None: + """Mark the state groups as having finished persistence. + + If `error` is true then we assume the state groups were not persisted, + and so we do not clear them from the pending deletion table. + """ + self.db_pool.simple_delete_many_txn( + txn, + table="state_groups_persisting", + column="state_group", + values=state_groups, + keyvalues={"instance_name": self._instance_name}, + ) + + if error: + # The state groups may or may not have been persisted, so we need to + # bump the deletion to ensure we recheck if they have become + # referenced. + self._bump_deletion_txn(txn, state_groups) + return + + self.db_pool.simple_delete_many_batch_txn( + txn, + table="state_groups_pending_deletion", + keys=("state_group",), + values=[(state_group,) for state_group in state_groups], + ) + + async def mark_state_groups_as_pending_deletion( + self, state_groups: Collection[int] + ) -> None: + """Mark the given state groups as pending deletion""" + + now = self._clock.time_msec() + + await self.db_pool.simple_upsert_many( + table="state_groups_pending_deletion", + key_names=("state_group",), + key_values=[(state_group,) for state_group in state_groups], + value_names=("insertion_ts",), + value_values=[(now,) for _ in state_groups], + desc="mark_state_groups_as_pending_deletion", + ) + + async def get_pending_deletions( + self, state_groups: Collection[int] + ) -> Mapping[int, int]: + """Get which state groups are pending deletion. + + Returns: + a mapping from state groups that are pending deletion to their + sequence number + """ + + rows = await self.db_pool.simple_select_many_batch( + table="state_groups_pending_deletion", + column="state_group", + iterable=state_groups, + retcols=("state_group", "sequence_number"), + keyvalues={}, + desc="get_pending_deletions", + ) + + return dict(rows) + + def get_state_groups_ready_for_potential_deletion_txn( + self, + txn: LoggingTransaction, + state_groups_to_sequence_numbers: Mapping[int, int], + ) -> Collection[int]: + """Given a set of state groups, return which state groups can + potentially be deleted. + + The state groups must have been checked to see if they remain + unreferenced before calling this function. + + Note: This must be called within the same transaction that the state + groups are deleted. + + Args: + state_groups_to_sequence_numbers: The state groups, and the sequence + numbers from before the state groups were checked to see if they + were unreferenced. + + Returns: + The subset of state groups that can safely be deleted + + """ + + if not state_groups_to_sequence_numbers: + return state_groups_to_sequence_numbers + + if isinstance(self.db_pool.engine, PostgresEngine): + # On postgres we want to lock the rows FOR UPDATE as early as + # possible to help conflicts. + clause, args = make_in_list_sql_clause( + self.db_pool.engine, "id", state_groups_to_sequence_numbers + ) + sql = f""" + SELECT id FROM state_groups + WHERE {clause} + FOR UPDATE + """ + txn.execute(sql, args) + + # Check the deletion status in the DB of the given state groups + clause, args = make_in_list_sql_clause( + self.db_pool.engine, + column="state_group", + iterable=state_groups_to_sequence_numbers, + ) + + sql = f""" + SELECT state_group, insertion_ts, sequence_number FROM ( + SELECT state_group, insertion_ts, sequence_number FROM state_groups_pending_deletion + UNION + SELECT state_group, null, null FROM state_groups_persisting + ) AS s + WHERE {clause} + """ + + txn.execute(sql, args) + + # The above query will return potentially two rows per state group (one + # for each table), so we track which state groups have enough time + # elapsed and which are not ready to be persisted. + ready_to_be_deleted = set() + not_ready_to_be_deleted = set() + + now = self._clock.time_msec() + for state_group, insertion_ts, sequence_number in txn: + if insertion_ts is None: + # A null insertion_ts means that we are currently persisting + # events that reference the state group, so we don't delete + # them. + not_ready_to_be_deleted.add(state_group) + continue + + # We know this can't be None if insertion_ts is not None + assert sequence_number is not None + + # Check if the sequence number has changed, if it has then it + # indicates that the state group may have become referenced since we + # checked. + if state_groups_to_sequence_numbers[state_group] != sequence_number: + not_ready_to_be_deleted.add(state_group) + continue + + if now - insertion_ts < self.DELAY_BEFORE_DELETION_MS: + # Not enough time has elapsed to allow us to delete. + not_ready_to_be_deleted.add(state_group) + continue + + ready_to_be_deleted.add(state_group) + + can_be_deleted = ready_to_be_deleted - not_ready_to_be_deleted + if not_ready_to_be_deleted: + # If there are any state groups that aren't ready to be deleted, + # then we also need to remove any state groups that are referenced + # by them. + clause, args = make_in_list_sql_clause( + self.db_pool.engine, + column="state_group", + iterable=state_groups_to_sequence_numbers, + ) + sql = f""" + WITH RECURSIVE ancestors(state_group) AS ( + SELECT DISTINCT prev_state_group + FROM state_group_edges WHERE {clause} + UNION + SELECT prev_state_group + FROM state_group_edges + INNER JOIN ancestors USING (state_group) + ) + SELECT state_group FROM ancestors + """ + txn.execute(sql, args) + + can_be_deleted.difference_update(state_group for (state_group,) in txn) + + return can_be_deleted