diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index ea7d8199a7..5b594fe8dd 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -20,7 +20,15 @@
#
import logging
-from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Union,
+)
from synapse.logging.opentracing import tag_args, trace
from synapse.storage._base import SQLBaseStore
@@ -112,8 +120,8 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
Returns:
Map from state_group to a StateMap at that point.
"""
-
- state_filter = state_filter or StateFilter.all()
+ if state_filter is None:
+ state_filter = StateFilter.all()
results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
@@ -388,8 +396,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
return True, count
txn.execute(
- "SELECT state_group FROM state_group_edges"
- " WHERE state_group = ?",
+ "SELECT state_group FROM state_group_edges WHERE state_group = ?",
(state_group,),
)
diff --git a/synapse/storage/databases/state/deletion.py b/synapse/storage/databases/state/deletion.py
new file mode 100644
index 0000000000..f77c46f6ae
--- /dev/null
+++ b/synapse/storage/databases/state/deletion.py
@@ -0,0 +1,561 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2025 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+
+import contextlib
+from typing import (
+ TYPE_CHECKING,
+ AbstractSet,
+ AsyncIterator,
+ Collection,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+)
+
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
+from synapse.storage.engines import PostgresEngine
+from synapse.util.stringutils import shortstr
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+class StateDeletionDataStore:
+ """Manages deletion of state groups in a safe manner.
+
+ Deleting state groups is challenging as before we actually delete them we
+ need to ensure that there are no in-flight events that refer to the state
+ groups that we want to delete.
+
+ To handle this, we take two approaches. First, before we persist any event
+ we ensure that the state group still exists and mark in the
+ `state_groups_persisting` table that the state group is about to be used.
+ (Note that we have to have the extra table here as state groups and events
+ can be in different databases, and thus we can't check for the existence of
+ state groups in the persist event transaction). Once the event has been
+ persisted, we can remove the row from `state_groups_persisting`. So long as
+ we check that table before deleting state groups, we can ensure that we
+ never persist events that reference deleted state groups, maintaining
+ database integrity.
+
+ However, we want to avoid throwing exceptions so deep in the process of
+ persisting events. So instead of deleting state groups immediately, we mark
+ them as pending/proposed for deletion and wait for a certain amount of time
+ before performing the deletion. When we come to handle new events that
+ reference state groups, we check if they are pending deletion and bump the
+ time for when they'll be deleted (to give a chance for the event to be
+ persisted, or not).
+
+ When deleting, we need to check that state groups remain unreferenced. There
+ is a race here where we a) fetch state groups that are ready for deletion,
+ b) check they're unreferenced, c) the state group becomes referenced but
+ then gets marked as pending deletion again, d) during the deletion
+ transaction we recheck `state_groups_pending_deletion` table again and see
+ that it exists and so continue with the deletion. To prevent this from
+ happening we add a `sequence_number` column to
+ `state_groups_pending_deletion`, and during deletion we ensure that for a
+ state group we're about to delete that the sequence number doesn't change
+ between steps (a) and (d). So long as we always bump the sequence number
+ whenever an event may become used the race can never happen.
+ """
+
+ # How long to wait before we delete state groups. This should be long enough
+ # for any in-flight events to be persisted. If events take longer to persist
+ # and any of the state groups they reference have been deleted, then the
+ # event will fail to persist (as well as any event in the same batch).
+ DELAY_BEFORE_DELETION_MS = 10 * 60 * 1000
+
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ self._clock = hs.get_clock()
+ self.db_pool = database
+ self._instance_name = hs.get_instance_name()
+
+ with db_conn.cursor(txn_name="_clear_existing_persising") as txn:
+ self._clear_existing_persising(txn)
+
+ def _clear_existing_persising(self, txn: LoggingTransaction) -> None:
+ """On startup we clear any entries in `state_groups_persisting` that
+ match our instance name, in case of a previous unclean shutdown"""
+
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="state_groups_persisting",
+ keyvalues={"instance_name": self._instance_name},
+ )
+
+ async def check_state_groups_and_bump_deletion(
+ self, state_groups: AbstractSet[int]
+ ) -> Collection[int]:
+ """Checks to make sure that the state groups haven't been deleted, and
+ if they're pending deletion we delay it (allowing time for any event
+ that will use them to finish persisting).
+
+ Returns:
+ The state groups that are missing, if any.
+ """
+
+ return await self.db_pool.runInteraction(
+ "check_state_groups_and_bump_deletion",
+ self._check_state_groups_and_bump_deletion_txn,
+ state_groups,
+ # We don't need to lock if we're just doing a quick check, as the
+ # lock doesn't prevent any races here.
+ lock=False,
+ )
+
+ def _check_state_groups_and_bump_deletion_txn(
+ self, txn: LoggingTransaction, state_groups: AbstractSet[int], lock: bool = True
+ ) -> Collection[int]:
+ """Checks to make sure that the state groups haven't been deleted, and
+ if they're pending deletion we delay it (allowing time for any event
+ that will use them to finish persisting).
+
+ The `lock` flag sets if we should lock the `state_group` rows we're
+ checking, which we should do when storing new groups.
+
+ Returns:
+ The state groups that are missing, if any.
+ """
+
+ existing_state_groups = self._get_existing_groups_with_lock(
+ txn, state_groups, lock=lock
+ )
+
+ self._bump_deletion_txn(txn, existing_state_groups)
+
+ missing_state_groups = state_groups - existing_state_groups
+ if missing_state_groups:
+ return missing_state_groups
+
+ return ()
+
+ def _bump_deletion_txn(
+ self, txn: LoggingTransaction, state_groups: Collection[int]
+ ) -> None:
+ """Update any pending deletions of the state group that they may now be
+ referenced."""
+
+ if not state_groups:
+ return
+
+ now = self._clock.time_msec()
+ if isinstance(self.db_pool.engine, PostgresEngine):
+ clause, args = make_in_list_sql_clause(
+ self.db_pool.engine, "state_group", state_groups
+ )
+ sql = f"""
+ UPDATE state_groups_pending_deletion
+ SET sequence_number = DEFAULT, insertion_ts = ?
+ WHERE {clause}
+ """
+ args.insert(0, now)
+ txn.execute(sql, args)
+ else:
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="state_groups_pending_deletion",
+ column="state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("state_group",),
+ )
+ if not rows:
+ return
+
+ state_groups_to_update = [state_group for (state_group,) in rows]
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="state_groups_pending_deletion",
+ column="state_group",
+ values=state_groups_to_update,
+ keyvalues={},
+ )
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups_pending_deletion",
+ keys=("state_group", "insertion_ts"),
+ values=[(state_group, now) for state_group in state_groups_to_update],
+ )
+
+ def _get_existing_groups_with_lock(
+ self, txn: LoggingTransaction, state_groups: Collection[int], lock: bool = True
+ ) -> AbstractSet[int]:
+ """Return which of the given state groups are in the database, and locks
+ those rows with `KEY SHARE` to ensure they don't get concurrently
+ deleted (if `lock` is true)."""
+ clause, args = make_in_list_sql_clause(self.db_pool.engine, "id", state_groups)
+
+ sql = f"""
+ SELECT id FROM state_groups
+ WHERE {clause}
+ """
+ if lock and isinstance(self.db_pool.engine, PostgresEngine):
+ # On postgres we add a row level lock to the rows to ensure that we
+ # conflict with any concurrent DELETEs. `FOR KEY SHARE` lock will
+ # not conflict with other read
+ sql += """
+ FOR KEY SHARE
+ """
+
+ txn.execute(sql, args)
+ return {state_group for (state_group,) in txn}
+
+ @contextlib.asynccontextmanager
+ async def persisting_state_group_references(
+ self, event_and_contexts: Collection[Tuple[EventBase, EventContext]]
+ ) -> AsyncIterator[None]:
+ """Wraps the persistence of the given events and contexts, ensuring that
+ any state groups referenced still exist and that they don't get deleted
+ during this."""
+
+ referenced_state_groups: Set[int] = set()
+ for event, ctx in event_and_contexts:
+ if ctx.rejected or event.internal_metadata.is_outlier():
+ continue
+
+ assert ctx.state_group is not None
+
+ referenced_state_groups.add(ctx.state_group)
+
+ if ctx.state_group_before_event:
+ referenced_state_groups.add(ctx.state_group_before_event)
+
+ if not referenced_state_groups:
+ # We don't reference any state groups, so nothing to do
+ yield
+ return
+
+ await self.db_pool.runInteraction(
+ "mark_state_groups_as_persisting",
+ self._mark_state_groups_as_persisting_txn,
+ referenced_state_groups,
+ )
+
+ error = True
+ try:
+ yield None
+ error = False
+ finally:
+ await self.db_pool.runInteraction(
+ "finish_persisting",
+ self._finish_persisting_txn,
+ referenced_state_groups,
+ error=error,
+ )
+
+ def _mark_state_groups_as_persisting_txn(
+ self, txn: LoggingTransaction, state_groups: Set[int]
+ ) -> None:
+ """Marks the given state groups as being persisted."""
+
+ existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups)
+ missing_state_groups = state_groups - existing_state_groups
+ if missing_state_groups:
+ raise Exception(
+ f"state groups have been deleted: {shortstr(missing_state_groups)}"
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups_persisting",
+ keys=("state_group", "instance_name"),
+ values=[(state_group, self._instance_name) for state_group in state_groups],
+ )
+
+ def _finish_persisting_txn(
+ self, txn: LoggingTransaction, state_groups: Collection[int], error: bool
+ ) -> None:
+ """Mark the state groups as having finished persistence.
+
+ If `error` is true then we assume the state groups were not persisted,
+ and so we do not clear them from the pending deletion table.
+ """
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="state_groups_persisting",
+ column="state_group",
+ values=state_groups,
+ keyvalues={"instance_name": self._instance_name},
+ )
+
+ if error:
+ # The state groups may or may not have been persisted, so we need to
+ # bump the deletion to ensure we recheck if they have become
+ # referenced.
+ self._bump_deletion_txn(txn, state_groups)
+ return
+
+ self.db_pool.simple_delete_many_batch_txn(
+ txn,
+ table="state_groups_pending_deletion",
+ keys=("state_group",),
+ values=[(state_group,) for state_group in state_groups],
+ )
+
+ async def mark_state_groups_as_pending_deletion(
+ self, state_groups: Collection[int]
+ ) -> None:
+ """Mark the given state groups as pending deletion.
+
+ If any of the state groups are already pending deletion, then those records are
+ left as is.
+ """
+
+ await self.db_pool.runInteraction(
+ "mark_state_groups_as_pending_deletion",
+ self._mark_state_groups_as_pending_deletion_txn,
+ state_groups,
+ )
+
+ def _mark_state_groups_as_pending_deletion_txn(
+ self,
+ txn: LoggingTransaction,
+ state_groups: Collection[int],
+ ) -> None:
+ sql = """
+ INSERT INTO state_groups_pending_deletion (state_group, insertion_ts)
+ VALUES %s
+ ON CONFLICT (state_group)
+ DO NOTHING
+ """
+
+ now = self._clock.time_msec()
+ rows = [
+ (
+ state_group,
+ now,
+ )
+ for state_group in state_groups
+ ]
+ if isinstance(txn.database_engine, PostgresEngine):
+ txn.execute_values(sql % ("?",), rows, fetch=False)
+ else:
+ txn.execute_batch(sql % ("(?, ?)",), rows)
+
+ async def mark_state_groups_as_used(self, state_groups: Collection[int]) -> None:
+ """Mark the given state groups as now being referenced"""
+
+ await self.db_pool.simple_delete_many(
+ table="state_groups_pending_deletion",
+ column="state_group",
+ iterable=state_groups,
+ keyvalues={},
+ desc="mark_state_groups_as_used",
+ )
+
+ async def get_pending_deletions(
+ self, state_groups: Collection[int]
+ ) -> Mapping[int, int]:
+ """Get which state groups are pending deletion.
+
+ Returns:
+ a mapping from state groups that are pending deletion to their
+ sequence number
+ """
+
+ rows = await self.db_pool.simple_select_many_batch(
+ table="state_groups_pending_deletion",
+ column="state_group",
+ iterable=state_groups,
+ retcols=("state_group", "sequence_number"),
+ keyvalues={},
+ desc="get_pending_deletions",
+ )
+
+ return dict(rows)
+
+ def get_state_groups_ready_for_potential_deletion_txn(
+ self,
+ txn: LoggingTransaction,
+ state_groups_to_sequence_numbers: Mapping[int, int],
+ ) -> Collection[int]:
+ """Given a set of state groups, return which state groups can
+ potentially be deleted.
+
+ The state groups must have been checked to see if they remain
+ unreferenced before calling this function.
+
+ Note: This must be called within the same transaction that the state
+ groups are deleted.
+
+ Args:
+ state_groups_to_sequence_numbers: The state groups, and the sequence
+ numbers from before the state groups were checked to see if they
+ were unreferenced.
+
+ Returns:
+ The subset of state groups that can safely be deleted
+
+ """
+
+ if not state_groups_to_sequence_numbers:
+ return state_groups_to_sequence_numbers
+
+ if isinstance(self.db_pool.engine, PostgresEngine):
+ # On postgres we want to lock the rows FOR UPDATE as early as
+ # possible to help conflicts.
+ clause, args = make_in_list_sql_clause(
+ self.db_pool.engine, "id", state_groups_to_sequence_numbers
+ )
+ sql = f"""
+ SELECT id FROM state_groups
+ WHERE {clause}
+ FOR UPDATE
+ """
+ txn.execute(sql, args)
+
+ # Check the deletion status in the DB of the given state groups
+ clause, args = make_in_list_sql_clause(
+ self.db_pool.engine,
+ column="state_group",
+ iterable=state_groups_to_sequence_numbers,
+ )
+
+ sql = f"""
+ SELECT state_group, insertion_ts, sequence_number FROM (
+ SELECT state_group, insertion_ts, sequence_number FROM state_groups_pending_deletion
+ UNION
+ SELECT state_group, null, null FROM state_groups_persisting
+ ) AS s
+ WHERE {clause}
+ """
+
+ txn.execute(sql, args)
+
+ # The above query will return potentially two rows per state group (one
+ # for each table), so we track which state groups have enough time
+ # elapsed and which are not ready to be persisted.
+ ready_to_be_deleted = set()
+ not_ready_to_be_deleted = set()
+
+ now = self._clock.time_msec()
+ for state_group, insertion_ts, sequence_number in txn:
+ if insertion_ts is None:
+ # A null insertion_ts means that we are currently persisting
+ # events that reference the state group, so we don't delete
+ # them.
+ not_ready_to_be_deleted.add(state_group)
+ continue
+
+ # We know this can't be None if insertion_ts is not None
+ assert sequence_number is not None
+
+ # Check if the sequence number has changed, if it has then it
+ # indicates that the state group may have become referenced since we
+ # checked.
+ if state_groups_to_sequence_numbers[state_group] != sequence_number:
+ not_ready_to_be_deleted.add(state_group)
+ continue
+
+ if now - insertion_ts < self.DELAY_BEFORE_DELETION_MS:
+ # Not enough time has elapsed to allow us to delete.
+ not_ready_to_be_deleted.add(state_group)
+ continue
+
+ ready_to_be_deleted.add(state_group)
+
+ can_be_deleted = ready_to_be_deleted - not_ready_to_be_deleted
+ if not_ready_to_be_deleted:
+ # If there are any state groups that aren't ready to be deleted,
+ # then we also need to remove any state groups that are referenced
+ # by them.
+ clause, args = make_in_list_sql_clause(
+ self.db_pool.engine,
+ column="state_group",
+ iterable=state_groups_to_sequence_numbers,
+ )
+ sql = f"""
+ WITH RECURSIVE ancestors(state_group) AS (
+ SELECT DISTINCT prev_state_group
+ FROM state_group_edges WHERE {clause}
+ UNION
+ SELECT prev_state_group
+ FROM state_group_edges
+ INNER JOIN ancestors USING (state_group)
+ )
+ SELECT state_group FROM ancestors
+ """
+ txn.execute(sql, args)
+
+ can_be_deleted.difference_update(state_group for (state_group,) in txn)
+
+ return can_be_deleted
+
+ async def get_next_state_group_collection_to_delete(
+ self,
+ ) -> Optional[Tuple[str, Mapping[int, int]]]:
+ """Get the next set of state groups to try and delete
+
+ Returns:
+ 2-tuple of room_id and mapping of state groups to sequence number.
+ """
+ return await self.db_pool.runInteraction(
+ "get_next_state_group_collection_to_delete",
+ self._get_next_state_group_collection_to_delete_txn,
+ )
+
+ def _get_next_state_group_collection_to_delete_txn(
+ self,
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, Mapping[int, int]]]:
+ """Implementation of `get_next_state_group_collection_to_delete`"""
+
+ # We want to return chunks of state groups that were marked for deletion
+ # at the same time (this isn't necessary, just more efficient). We do
+ # this by looking for the oldest insertion_ts, and then pulling out all
+ # rows that have the same insertion_ts (and room ID).
+ now = self._clock.time_msec()
+
+ sql = """
+ SELECT room_id, insertion_ts
+ FROM state_groups_pending_deletion AS sd
+ INNER JOIN state_groups AS sg ON (id = sd.state_group)
+ LEFT JOIN state_groups_persisting AS sp USING (state_group)
+ WHERE insertion_ts < ? AND sp.state_group IS NULL
+ ORDER BY insertion_ts
+ LIMIT 1
+ """
+ txn.execute(sql, (now - self.DELAY_BEFORE_DELETION_MS,))
+ row = txn.fetchone()
+ if not row:
+ return None
+
+ (room_id, insertion_ts) = row
+
+ sql = """
+ SELECT state_group, sequence_number
+ FROM state_groups_pending_deletion AS sd
+ INNER JOIN state_groups AS sg ON (id = sd.state_group)
+ LEFT JOIN state_groups_persisting AS sp USING (state_group)
+ WHERE room_id = ? AND insertion_ts = ? AND sp.state_group IS NULL
+ ORDER BY insertion_ts
+ """
+ txn.execute(sql, (room_id, insertion_ts))
+
+ return room_id, dict(txn)
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index d4ac74c1ee..c1a66dcba0 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -22,10 +22,10 @@
import logging
from typing import (
TYPE_CHECKING,
- Collection,
Dict,
Iterable,
List,
+ Mapping,
Optional,
Set,
Tuple,
@@ -36,7 +36,10 @@ import attr
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase
+from synapse.events.snapshot import (
+ UnpersistedEventContext,
+ UnpersistedEventContextBase,
+)
from synapse.logging.opentracing import tag_args, trace
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -45,6 +48,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
+from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import MutableStateMap, StateKey, StateMap
@@ -55,6 +59,7 @@ from synapse.util.cancellation import cancellable
if TYPE_CHECKING:
from synapse.server import HomeServer
+ from synapse.storage.databases.state.deletion import StateDeletionDataStore
logger = logging.getLogger(__name__)
@@ -83,8 +88,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
+ state_deletion_store: "StateDeletionDataStore",
):
super().__init__(database, db_conn, hs)
+ self._state_deletion_store = state_deletion_store
# Originally the state store used a single DictionaryCache to cache the
# event IDs for the state types in a given state group to avoid hammering
@@ -284,7 +291,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Returns:
Dict of state group to state map.
"""
- state_filter = state_filter or StateFilter.all()
+ if state_filter is None:
+ state_filter = StateFilter.all()
member_filter, non_member_filter = state_filter.get_member_split()
@@ -466,14 +474,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Returns:
A list of state groups
"""
- is_in_db = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="state_groups",
- keyvalues={"id": prev_group},
- retcol="id",
- allow_none=True,
+
+ # We need to check that the prev group isn't about to be deleted
+ is_missing = (
+ self._state_deletion_store._check_state_groups_and_bump_deletion_txn(
+ txn,
+ {prev_group},
+ )
)
- if not is_in_db:
+ if is_missing:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (prev_group,)
@@ -545,6 +554,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
for key, state_id in context.state_delta_due_to_event.items()
],
)
+
return events_and_context
return await self.db_pool.runInteraction(
@@ -600,14 +610,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
The state group if successfully created, or None if the state
needs to be persisted as a full state.
"""
- is_in_db = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="state_groups",
- keyvalues={"id": prev_group},
- retcol="id",
- allow_none=True,
+
+ # We need to check that the prev group isn't about to be deleted
+ is_missing = (
+ self._state_deletion_store._check_state_groups_and_bump_deletion_txn(
+ txn,
+ {prev_group},
+ )
)
- if not is_in_db:
+ if is_missing:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (prev_group,)
@@ -725,8 +736,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
async def purge_unreferenced_state_groups(
- self, room_id: str, state_groups_to_delete: Collection[int]
- ) -> None:
+ self,
+ room_id: str,
+ state_groups_to_sequence_numbers: Mapping[int, int],
+ ) -> bool:
"""Deletes no longer referenced state groups and de-deltas any state
groups that reference them.
@@ -734,21 +747,31 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
room_id: The room the state groups belong to (must all be in the
same room).
state_groups_to_delete: Set of all state groups to delete.
+
+ Returns:
+ Whether any state groups were actually deleted.
"""
- await self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"purge_unreferenced_state_groups",
self._purge_unreferenced_state_groups,
room_id,
- state_groups_to_delete,
+ state_groups_to_sequence_numbers,
)
def _purge_unreferenced_state_groups(
self,
txn: LoggingTransaction,
room_id: str,
- state_groups_to_delete: Collection[int],
- ) -> None:
+ state_groups_to_sequence_numbers: Mapping[int, int],
+ ) -> bool:
+ state_groups_to_delete = self._state_deletion_store.get_state_groups_ready_for_potential_deletion_txn(
+ txn, state_groups_to_sequence_numbers
+ )
+
+ if not state_groups_to_delete:
+ return False
+
logger.info(
"[purge] found %i state groups to delete", len(state_groups_to_delete)
)
@@ -767,7 +790,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
remaining_state_groups = {
state_group
- for state_group, in rows
+ for (state_group,) in rows
if state_group not in state_groups_to_delete
}
@@ -804,13 +827,23 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
logger.info("[purge] removing redundant state groups")
txn.execute_batch(
"DELETE FROM state_groups_state WHERE state_group = ?",
- ((sg,) for sg in state_groups_to_delete),
+ [(sg,) for sg in state_groups_to_delete],
+ )
+ txn.execute_batch(
+ "DELETE FROM state_group_edges WHERE state_group = ?",
+ [(sg,) for sg in state_groups_to_delete],
)
txn.execute_batch(
"DELETE FROM state_groups WHERE id = ?",
- ((sg,) for sg in state_groups_to_delete),
+ [(sg,) for sg in state_groups_to_delete],
+ )
+ txn.execute_batch(
+ "DELETE FROM state_groups_pending_deletion WHERE state_group = ?",
+ [(sg,) for sg in state_groups_to_delete],
)
+ return True
+
@trace
@tag_args
async def get_previous_state_groups(
@@ -829,7 +862,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
List[Tuple[int, int]],
await self.db_pool.simple_select_many_batch(
table="state_group_edges",
- column="prev_state_group",
+ column="state_group",
iterable=state_groups,
keyvalues={},
retcols=("state_group", "prev_state_group"),
@@ -839,60 +872,77 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return dict(rows)
- async def purge_room_state(
- self, room_id: str, state_groups_to_delete: Collection[int]
- ) -> None:
- """Deletes all record of a room from state tables
+ @trace
+ @tag_args
+ async def get_next_state_groups(
+ self, state_groups: Iterable[int]
+ ) -> Dict[int, int]:
+ """Fetch the groups that have the given state groups as their previous
+ state groups.
Args:
- room_id:
- state_groups_to_delete: State groups to delete
+ state_groups
+
+ Returns:
+ A mapping from state group to previous state group.
"""
- logger.info("[purge] Starting state purge")
- await self.db_pool.runInteraction(
+ rows = cast(
+ List[Tuple[int, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("state_group", "prev_state_group"),
+ desc="get_next_state_groups",
+ ),
+ )
+
+ return dict(rows)
+
+ async def purge_room_state(self, room_id: str) -> None:
+ return await self.db_pool.runInteraction(
"purge_room_state",
self._purge_room_state_txn,
room_id,
- state_groups_to_delete,
)
- logger.info("[purge] Done with state purge")
def _purge_room_state_txn(
self,
txn: LoggingTransaction,
room_id: str,
- state_groups_to_delete: Collection[int],
) -> None:
- # first we have to delete the state groups states
- logger.info("[purge] removing %s from state_groups_state", room_id)
+ # Delete all edges that reference a state group linked to room_id
+ logger.info("[purge] removing %s from state_group_edges", room_id)
- self.db_pool.simple_delete_many_txn(
- txn,
- table="state_groups_state",
- column="state_group",
- values=state_groups_to_delete,
- keyvalues={},
- )
+ if isinstance(self.database_engine, PostgresEngine):
+ # Disable statement timeouts for this transaction; purging rooms can
+ # take a while!
+ txn.execute("SET LOCAL statement_timeout = 0")
- # ... and the state group edges
- logger.info("[purge] removing %s from state_group_edges", room_id)
+ txn.execute(
+ """
+ DELETE FROM state_group_edges AS sge WHERE sge.state_group IN (
+ SELECT id FROM state_groups AS sg WHERE sg.room_id = ?
+ )""",
+ (room_id,),
+ )
- self.db_pool.simple_delete_many_txn(
- txn,
- table="state_group_edges",
- column="state_group",
- values=state_groups_to_delete,
- keyvalues={},
+ # state_groups_state table has a room_id column but no index on it, unlike state_groups,
+ # so we delete them by matching the room_id through the state_groups table.
+ logger.info("[purge] removing %s from state_groups_state", room_id)
+ txn.execute(
+ """
+ DELETE FROM state_groups_state AS sgs WHERE sgs.state_group IN (
+ SELECT id FROM state_groups AS sg WHERE sg.room_id = ?
+ )""",
+ (room_id,),
)
- # ... and the state groups
logger.info("[purge] removing %s from state_groups", room_id)
-
- self.db_pool.simple_delete_many_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="state_groups",
- column="id",
- values=state_groups_to_delete,
- keyvalues={},
+ keyvalues={"room_id": room_id},
)
|