diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index a803ada8ad..e126a2e0c5 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -61,6 +61,7 @@ from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpda
from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore
+from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
from synapse.storage.databases.main.event_push_actions import EventPushActionsStore
from synapse.storage.databases.main.events_bg_updates import (
EventsBackgroundUpdatesStore,
@@ -239,6 +240,7 @@ class Store(
PresenceBackgroundUpdateStore,
ReceiptsBackgroundUpdateStore,
RelationsWorkerStore,
+ EventFederationWorkerStore,
):
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index edc97a9d61..5dce0a0159 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
-from enum import IntEnum
+from enum import Enum, IntEnum
from types import TracebackType
from typing import (
TYPE_CHECKING,
@@ -24,12 +25,16 @@ from typing import (
Iterable,
List,
Optional,
+ Sequence,
+ Tuple,
Type,
)
import attr
+from pydantic import BaseModel
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection, Cursor
from synapse.types import JsonDict
from synapse.util import Clock, json_encoder
@@ -48,6 +53,78 @@ DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+class Constraint(metaclass=abc.ABCMeta):
+ """Base class representing different constraints.
+
+ Used by `register_background_validate_constraint_and_delete_rows`.
+ """
+
+ @abc.abstractmethod
+ def make_check_clause(self, table: str) -> str:
+ """Returns an SQL expression that checks the row passes the constraint."""
+ pass
+
+ @abc.abstractmethod
+ def make_constraint_clause_postgres(self) -> str:
+ """Returns an SQL clause for creating the constraint.
+
+ Only used on Postgres DBs
+ """
+ pass
+
+
+@attr.s(auto_attribs=True)
+class ForeignKeyConstraint(Constraint):
+ """A foreign key constraint.
+
+ Attributes:
+ referenced_table: The "parent" table name.
+ columns: The list of mappings of columns from table to referenced table
+ """
+
+ referenced_table: str
+ columns: Sequence[Tuple[str, str]]
+
+ def make_check_clause(self, table: str) -> str:
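+        # Illustrative example (matching the constraint added elsewhere in this
+        # patch): with referenced_table="events", columns=[("event_id", "event_id")]
+        # and table="event_forward_extremities", this produces:
+        #   EXISTS (SELECT 1 FROM events WHERE event_id = event_forward_extremities.event_id)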
+ join_clause = " AND ".join(
+ f"{col1} = {table}.{col2}" for col1, col2 in self.columns
+ )
+ return f"EXISTS (SELECT 1 FROM {self.referenced_table} WHERE {join_clause})"
+
+ def make_constraint_clause_postgres(self) -> str:
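+        # With the same illustrative example as above, this produces:
+        #   FOREIGN KEY (event_id) REFERENCES events (event_id)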
+ column1_list = ", ".join(col1 for col1, col2 in self.columns)
+ column2_list = ", ".join(col2 for col1, col2 in self.columns)
+ return f"FOREIGN KEY ({column1_list}) REFERENCES {self.referenced_table} ({column2_list})"
+
+
+@attr.s(auto_attribs=True)
+class NotNullConstraint(Constraint):
+ """A NOT NULL column constraint"""
+
+ column: str
+
+ def make_check_clause(self, table: str) -> str:
+ return f"{self.column} IS NOT NULL"
+
+ def make_constraint_clause_postgres(self) -> str:
+ return f"CHECK ({self.column} IS NOT NULL)"
+
+
+class ValidateConstraintProgress(BaseModel):
+ """The format of the progress JSON for validate constraint background
+ updates.
+
+ Used by `register_background_validate_constraint_and_delete_rows`.
+ """
+
+ class State(str, Enum):
+ check = "check"
+ validate = "validate"
+
+ state: State = State.validate
+ lower_bound: Sequence[Any] = ()
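+
+    # Illustrative serialized form, as stored in `background_updates.progress_json`:
+    #   {"state": "check", "lower_bound": ["<event_id>", "<room_id>"]}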
+
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _BackgroundUpdateHandler:
"""A handler for a given background update.
@@ -740,6 +817,179 @@ class BackgroundUpdater:
logger.info("Adding index %s to %s", index_name, table)
await self.db_pool.runWithConnection(runner)
+ def register_background_validate_constraint_and_delete_rows(
+ self,
+ update_name: str,
+ table: str,
+ constraint_name: str,
+ constraint: Constraint,
+ unique_columns: Sequence[str],
+ ) -> None:
+ """Helper for store classes to do a background validate constraint, and
+ delete rows that do not pass the constraint check.
+
+ Note: This deletes rows that don't match the constraint. This may not be
+ appropriate in all situations, and so the suitability of using this
+ method should be considered on a case-by-case basis.
+
+ This only applies on PostgreSQL.
+
+        For SQLite the table is recreated as part of the schema delta, and the
+        rows that satisfy the constraint are copied across synchronously.
+
+ Args:
+ update_name: The name of the background update.
+ table: The table with the invalid constraint.
+ constraint_name: The name of the constraint
+ constraint: A `Constraint` object matching the type of constraint.
+ unique_columns: A sequence of columns that form a unique constraint
+ on the table. Used to iterate over the table.
+ """
+
+        assert isinstance(
+            self.db_pool.engine, PostgresEngine
+        ), "validate constraint background update registered for non-Postgres database"
+
+ async def updater(progress: JsonDict, batch_size: int) -> int:
+ return await self.validate_constraint_and_delete_in_background(
+ update_name=update_name,
+ table=table,
+ constraint_name=constraint_name,
+ constraint=constraint,
+ unique_columns=unique_columns,
+ progress=progress,
+ batch_size=batch_size,
+ )
+
+ self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+ updater, oneshot=True
+ )
+
+ async def validate_constraint_and_delete_in_background(
+ self,
+ update_name: str,
+ table: str,
+ constraint_name: str,
+ constraint: Constraint,
+ unique_columns: Sequence[str],
+ progress: JsonDict,
+ batch_size: int,
+ ) -> int:
+ """Validates a table constraint that has been marked as `NOT VALID`,
+ deleting rows that don't pass the constraint check.
+
+        Args:
+            update_name: The name of the background update.
+            table: The table with the `NOT VALID` constraint.
+            constraint_name: The name of the constraint to validate.
+            constraint: A `Constraint` object matching the type of constraint.
+            unique_columns: A sequence of columns that form a unique constraint
+                on the table. Used to iterate over the table.
+            progress: The current progress JSON for this background update.
+            batch_size: The maximum number of rows to check in this batch.
+        """
+
+ # We validate the constraint by:
+ # 1. Trying to validate the constraint as is. If this succeeds then
+ # we're done.
+ # 2. Otherwise, we manually scan the table to remove rows that don't
+ # match the constraint.
+ # 3. We try re-validating the constraint.
+
+ parsed_progress = ValidateConstraintProgress.parse_obj(progress)
+
+ if parsed_progress.state == ValidateConstraintProgress.State.check:
+ return_columns = ", ".join(unique_columns)
+ order_columns = ", ".join(unique_columns)
+
+ where_clause = ""
+ args: List[Any] = []
+ if parsed_progress.lower_bound:
+ where_clause = f"""WHERE ({order_columns}) > ({", ".join("?" for _ in unique_columns)})"""
+ args.extend(parsed_progress.lower_bound)
+
+ args.append(batch_size)
+
+ sql = f"""
+ SELECT
+ {return_columns},
+ {constraint.make_check_clause(table)} AS check
+ FROM {table}
+ {where_clause}
+ ORDER BY {order_columns}
+ LIMIT ?
+ """
+
+ def validate_constraint_in_background_check(
+ txn: "LoggingTransaction",
+ ) -> None:
+ txn.execute(sql, args)
+ rows = txn.fetchall()
+
+ new_progress = parsed_progress.copy()
+
+ if not rows:
+ new_progress.state = ValidateConstraintProgress.State.validate
+ self._background_update_progress_txn(
+ txn, update_name, new_progress.dict()
+ )
+ return
+
+ new_progress.lower_bound = rows[-1][:-1]
+
+ to_delete = [row[:-1] for row in rows if not row[-1]]
+
+ if to_delete:
+ logger.warning(
+ "Deleting %d rows that do not pass new constraint",
+ len(to_delete),
+ )
+
+ self.db_pool.simple_delete_many_batch_txn(
+ txn, table=table, keys=unique_columns, values=to_delete
+ )
+
+ self._background_update_progress_txn(
+ txn, update_name, new_progress.dict()
+ )
+
+ await self.db_pool.runInteraction(
+ "validate_constraint_in_background_check",
+ validate_constraint_in_background_check,
+ )
+
+ return batch_size
+
+ elif parsed_progress.state == ValidateConstraintProgress.State.validate:
+ sql = f"ALTER TABLE {table} VALIDATE CONSTRAINT {constraint_name}"
+
+ def validate_constraint_in_background_validate(
+ txn: "LoggingTransaction",
+ ) -> None:
+ txn.execute(sql)
+
+ try:
+ await self.db_pool.runInteraction(
+ "validate_constraint_in_background_validate",
+ validate_constraint_in_background_validate,
+ )
+
+ await self._end_background_update(update_name)
+ except self.db_pool.engine.module.IntegrityError as e:
+ # If we get an integrity error here, then we go back and recheck the table.
+ logger.warning("Integrity error when validating constraint: %s", e)
+ await self._background_update_progress(
+ update_name,
+ ValidateConstraintProgress(
+ state=ValidateConstraintProgress.State.check
+ ).dict(),
+ )
+
+ return batch_size
+ else:
+ raise Exception(
+ f"Unrecognized state '{parsed_progress.state}' when trying to validate_constraint_and_delete_in_background"
+ )
+
async def _end_background_update(self, update_name: str) -> None:
"""Removes a completed background update task from the queue.
@@ -795,3 +1045,86 @@ class BackgroundUpdater:
keyvalues={"update_name": update_name},
updatevalues={"progress_json": progress_json},
)
+
+
+def run_validate_constraint_and_delete_rows_schema_delta(
+ txn: "LoggingTransaction",
+ ordering: int,
+ update_name: str,
+ table: str,
+ constraint_name: str,
+ constraint: Constraint,
+ sqlite_table_name: str,
+ sqlite_table_schema: str,
+) -> None:
+ """Runs a schema delta to add a constraint to the table. This should be run
+ in a schema delta file.
+
+ For PostgreSQL the constraint is added and validated in the background.
+
+ For SQLite the table is recreated and data copied across immediately. This
+ is done by the caller passing in a script to create the new table. Note that
+ table indexes and triggers are copied over automatically.
+
+ There must be a corresponding call to
+ `register_background_validate_constraint_and_delete_rows` to register the
+ background update in one of the data store classes.
+
+    Args:
+        txn: The transaction to run the schema delta in.
+        ordering, update_name: For adding a row to the background_updates table.
+        table: The table to add the constraint to.
+        constraint_name: The name of the new constraint.
+        constraint: A `Constraint` object describing the constraint.
+        sqlite_table_name: For SQLite, the name of the empty copy of the table.
+        sqlite_table_schema: A SQL script for creating the above table.
+    """
+
+ if isinstance(txn.database_engine, PostgresEngine):
+ # For postgres we can just add the constraint and mark it as NOT VALID,
+ # and then insert a background update to go and check the validity in
+ # the background.
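+        #
+        # E.g. for the delta added in this patch this runs (illustrative):
+        #   ALTER TABLE event_forward_extremities
+        #       ADD CONSTRAINT event_forward_extremities_event_id
+        #       FOREIGN KEY (event_id) REFERENCES events (event_id)
+        #       NOT VALID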
+ txn.execute(
+ f"""
+ ALTER TABLE {table}
+ ADD CONSTRAINT {constraint_name} {constraint.make_constraint_clause_postgres()}
+ NOT VALID
+ """
+ )
+
+ txn.execute(
+ "INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (?, ?, '{}')",
+ (ordering, update_name),
+ )
+ else:
+ # For SQLite, we:
+ # 1. fetch all indexes/triggers/etc related to the table
+ # 2. create an empty copy of the table
+ # 3. copy across the rows (that satisfy the check)
+    # 4. replace the old table with the new table.
+ # 5. add back all the indexes/triggers/etc
+
+    # Fetch the indexes/triggers/etc. Note that the `sql` column is NULL for
+    # indexes that were auto-created from the table definition (e.g. PRIMARY
+    # KEY), which don't need to be recreated.
+ txn.execute(
+ """
+ SELECT sql FROM sqlite_master
+ WHERE tbl_name = ? AND type != 'table' AND sql IS NOT NULL
+ """,
+ (table,),
+ )
+ extras = [row[0] for row in txn]
+
+ txn.execute(sqlite_table_schema)
+
+ sql = f"""
+ INSERT INTO {sqlite_table_name} SELECT * FROM {table}
+ WHERE {constraint.make_check_clause(table)}
+ """
+
+ txn.execute(sql)
+
+ txn.execute(f"DROP TABLE {table}")
+ txn.execute(f"ALTER TABLE {sqlite_table_name} RENAME TO {table}")
+
+ for extra in extras:
+ txn.execute(extra)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 7e49ae11bc..a1c8fb0f46 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -2313,6 +2313,43 @@ class DatabasePool:
return txn.rowcount
+ @staticmethod
+ def simple_delete_many_batch_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keys: Collection[str],
+ values: Iterable[Iterable[Any]],
+ ) -> None:
+ """Executes a DELETE query on the named table.
+
+ The input is given as a list of rows, where each row is a list of values.
+ (Actually any iterable is fine.)
+
+ Args:
+ txn: The transaction to use.
+ table: string giving the table name
+ keys: list of column names
+ values: for each row, a list of values in the same order as `keys`
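+
+        For example (illustrative), with keys ("event_id", "room_id") the SQLite
+        path runs `DELETE FROM <table> WHERE (event_id, room_id) = (?, ?)` once
+        per row via `execute_batch`, while the Postgres path deletes all rows in
+        a single `execute_values` call.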
+ """
+
+ if isinstance(txn.database_engine, PostgresEngine):
+ # We use `execute_values` as it can be a lot faster than `execute_batch`,
+ # but it's only available on postgres.
+ sql = "DELETE FROM %s WHERE (%s) IN (VALUES ?)" % (
+ table,
+ ", ".join(k for k in keys),
+ )
+
+ txn.execute_values(sql, values, fetch=False)
+ else:
+ sql = "DELETE FROM %s WHERE (%s) = (%s)" % (
+ table,
+ ", ".join(k for k in keys),
+ ", ".join("?" for _ in keys),
+ )
+
+ txn.execute_batch(sql, values)
+
def get_cache_dict(
self,
db_conn: LoggingDatabaseConnection,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 8b6e3c1dc7..dabe603c8c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -38,6 +38,7 @@ from synapse.events import EventBase, make_event_from_dict
from synapse.logging.opentracing import tag_args, trace
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.background_updates import ForeignKeyConstraint
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
@@ -140,6 +141,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
+ if isinstance(self.database_engine, PostgresEngine):
+ self.db_pool.updates.register_background_validate_constraint_and_delete_rows(
+ update_name="event_forward_extremities_event_id_foreign_key_constraint_update",
+ table="event_forward_extremities",
+ constraint_name="event_forward_extremities_event_id",
+ constraint=ForeignKeyConstraint("events", [("event_id", "event_id")]),
+ unique_columns=("event_id", "room_id"),
+ )
+
async def get_auth_chain(
self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5c9db7554e..2b83a69426 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -415,12 +415,6 @@ class PersistEventsStore:
backfilled=False,
)
- self._update_forward_extremities_txn(
- txn,
- new_forward_extremities=new_forward_extremities,
- max_stream_order=max_stream_order,
- )
-
# Ensure that we don't have the same event twice.
events_and_contexts = self._filter_events_and_contexts_for_duplicates(
events_and_contexts
@@ -439,6 +433,12 @@ class PersistEventsStore:
self._store_event_txn(txn, events_and_contexts=events_and_contexts)
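+        # Note: updating the forward extremities is moved to after
+        # `_store_event_txn` so that the referenced rows in `events` already
+        # exist when the new foreign key constraint on
+        # `event_forward_extremities` is checked.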
+ self._update_forward_extremities_txn(
+ txn,
+ new_forward_extremities=new_forward_extremities,
+ max_stream_order=max_stream_order,
+ )
+
self._persist_transaction_ids_txn(txn, events_and_contexts)
# Insert into event_to_state_groups.
diff --git a/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py b/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py
new file mode 100644
index 0000000000..f12e2a8f3e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py
@@ -0,0 +1,51 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This migration adds a foreign key constraint to the `event_forward_extremities` table.
+"""
+from synapse.storage.background_updates import (
+ ForeignKeyConstraint,
+ run_validate_constraint_and_delete_rows_schema_delta,
+)
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.engines import BaseDatabaseEngine
+
+FORWARD_EXTREMITIES_TABLE_SCHEMA = """
+ CREATE TABLE event_forward_extremities2(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ UNIQUE (event_id, room_id),
+ CONSTRAINT event_forward_extremities_event_id FOREIGN KEY (event_id) REFERENCES events (event_id)
+ )
+"""
+
+
+def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None:
+ run_validate_constraint_and_delete_rows_schema_delta(
+ cur,
+ ordering=7803,
+ update_name="event_forward_extremities_event_id_foreign_key_constraint_update",
+ table="event_forward_extremities",
+ constraint_name="event_forward_extremities_event_id",
+ constraint=ForeignKeyConstraint("events", [("event_id", "event_id")]),
+ sqlite_table_name="event_forward_extremities2",
+ sqlite_table_schema=FORWARD_EXTREMITIES_TABLE_SCHEMA,
+ )
+
+ # We can't add a similar constraint to `event_backward_extremities` as the
+ # events in there don't exist in the `events` table and `event_edges`
+ # doesn't have a unique constraint on `prev_event_id` (so we can't make a
+ # foreign key point to it).