summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/background_updates.py335
-rw-r--r--synapse/storage/database.py37
-rw-r--r--synapse/storage/databases/main/event_federation.py10
-rw-r--r--synapse/storage/databases/main/events.py12
-rw-r--r--synapse/storage/schema/main/delta/78/03event_extremities_constraints.py51
5 files changed, 438 insertions, 7 deletions
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index edc97a9d61..5dce0a0159 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import logging
-from enum import IntEnum
+from enum import Enum, IntEnum
 from types import TracebackType
 from typing import (
     TYPE_CHECKING,
@@ -24,12 +25,16 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sequence,
+    Tuple,
     Type,
 )
 
 import attr
+from pydantic import BaseModel
 
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Connection, Cursor
 from synapse.types import JsonDict
 from synapse.util import Clock, json_encoder
@@ -48,6 +53,78 @@ DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
 MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
 
 
+class Constraint(metaclass=abc.ABCMeta):
+    """Base class representing different constraints.
+
+    Used by `register_background_validate_constraint_and_delete_rows`.
+    """
+
+    @abc.abstractmethod
+    def make_check_clause(self, table: str) -> str:
+        """Returns an SQL expression that checks the row passes the constraint."""
+        pass
+
+    @abc.abstractmethod
+    def make_constraint_clause_postgres(self) -> str:
+        """Returns an SQL clause for creating the constraint.
+
+        Only used on Postgres DBs
+        """
+        pass
+
+
+@attr.s(auto_attribs=True)
+class ForeignKeyConstraint(Constraint):
+    """A foreign key constraint.
+
+    Attributes:
+        referenced_table: The "parent" table name.
+        columns: The list of mappings of columns from table to referenced table
+    """
+
+    referenced_table: str
+    columns: Sequence[Tuple[str, str]]
+
+    def make_check_clause(self, table: str) -> str:
+        join_clause = " AND ".join(
+            f"{col1} = {table}.{col2}" for col1, col2 in self.columns
+        )
+        return f"EXISTS (SELECT 1 FROM {self.referenced_table} WHERE {join_clause})"
+
+    def make_constraint_clause_postgres(self) -> str:
+        column1_list = ", ".join(col1 for col1, col2 in self.columns)
+        column2_list = ", ".join(col2 for col1, col2 in self.columns)
+        return f"FOREIGN KEY ({column1_list}) REFERENCES {self.referenced_table} ({column2_list})"
+
+
+@attr.s(auto_attribs=True)
+class NotNullConstraint(Constraint):
+    """A NOT NULL column constraint"""
+
+    column: str
+
+    def make_check_clause(self, table: str) -> str:
+        return f"{self.column} IS NOT NULL"
+
+    def make_constraint_clause_postgres(self) -> str:
+        return f"CHECK ({self.column} IS NOT NULL)"
+
+
+class ValidateConstraintProgress(BaseModel):
+    """The format of the progress JSON for validate constraint background
+    updates.
+
+    Used by `register_background_validate_constraint_and_delete_rows`.
+    """
+
+    class State(str, Enum):
+        check = "check"
+        validate = "validate"
+
+    state: State = State.validate
+    lower_bound: Sequence[Any] = ()
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _BackgroundUpdateHandler:
     """A handler for a given background update.
@@ -740,6 +817,179 @@ class BackgroundUpdater:
         logger.info("Adding index %s to %s", index_name, table)
         await self.db_pool.runWithConnection(runner)
 
+    def register_background_validate_constraint_and_delete_rows(
+        self,
+        update_name: str,
+        table: str,
+        constraint_name: str,
+        constraint: Constraint,
+        unique_columns: Sequence[str],
+    ) -> None:
+        """Helper for store classes to do a background validate constraint, and
+        delete rows that do not pass the constraint check.
+
+        Note: This deletes rows that don't match the constraint. This may not be
+        appropriate in all situations, and so the suitability of using this
+        method should be considered on a case-by-case basis.
+
+        This only applies on PostgreSQL.
+
+        For SQLite the table gets recreated as part of the schema delta and the
+        data is copied over synchronously (or whatever the correct way to
+        describe it as).
+
+        Args:
+            update_name: The name of the background update.
+            table: The table with the invalid constraint.
+            constraint_name: The name of the constraint
+            constraint: A `Constraint` object matching the type of constraint.
+            unique_columns: A sequence of columns that form a unique constraint
+              on the table. Used to iterate over the table.
+        """
+
+        assert isinstance(
+            self.db_pool.engine, engines.PostgresEngine
+        ), "validate constraint background update registered for non-Postres database"
+
+        async def updater(progress: JsonDict, batch_size: int) -> int:
+            return await self.validate_constraint_and_delete_in_background(
+                update_name=update_name,
+                table=table,
+                constraint_name=constraint_name,
+                constraint=constraint,
+                unique_columns=unique_columns,
+                progress=progress,
+                batch_size=batch_size,
+            )
+
+        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+            updater, oneshot=True
+        )
+
+    async def validate_constraint_and_delete_in_background(
+        self,
+        update_name: str,
+        table: str,
+        constraint_name: str,
+        constraint: Constraint,
+        unique_columns: Sequence[str],
+        progress: JsonDict,
+        batch_size: int,
+    ) -> int:
+        """Validates a table constraint that has been marked as `NOT VALID`,
+        deleting rows that don't pass the constraint check.
+
+        This will delete rows that do not meet the validation check.
+
+        update_name: str,
+        table: str,
+        constraint_name: str,
+        constraint: Constraint,
+        unique_columns: Sequence[str],
+        """
+
+        # We validate the constraint by:
+        #   1. Trying to validate the constraint as is. If this succeeds then
+        #      we're done.
+        #   2. Otherwise, we manually scan the table to remove rows that don't
+        #      match the constraint.
+        #   3. We try re-validating the constraint.
+
+        parsed_progress = ValidateConstraintProgress.parse_obj(progress)
+
+        if parsed_progress.state == ValidateConstraintProgress.State.check:
+            return_columns = ", ".join(unique_columns)
+            order_columns = ", ".join(unique_columns)
+
+            where_clause = ""
+            args: List[Any] = []
+            if parsed_progress.lower_bound:
+                where_clause = f"""WHERE ({order_columns}) > ({", ".join("?" for _ in unique_columns)})"""
+                args.extend(parsed_progress.lower_bound)
+
+            args.append(batch_size)
+
+            sql = f"""
+                SELECT
+                    {return_columns},
+                    {constraint.make_check_clause(table)} AS check
+                FROM {table}
+                {where_clause}
+                ORDER BY {order_columns}
+                LIMIT ?
+            """
+
+            def validate_constraint_in_background_check(
+                txn: "LoggingTransaction",
+            ) -> None:
+                txn.execute(sql, args)
+                rows = txn.fetchall()
+
+                new_progress = parsed_progress.copy()
+
+                if not rows:
+                    new_progress.state = ValidateConstraintProgress.State.validate
+                    self._background_update_progress_txn(
+                        txn, update_name, new_progress.dict()
+                    )
+                    return
+
+                new_progress.lower_bound = rows[-1][:-1]
+
+                to_delete = [row[:-1] for row in rows if not row[-1]]
+
+                if to_delete:
+                    logger.warning(
+                        "Deleting %d rows that do not pass new constraint",
+                        len(to_delete),
+                    )
+
+                    self.db_pool.simple_delete_many_batch_txn(
+                        txn, table=table, keys=unique_columns, values=to_delete
+                    )
+
+                self._background_update_progress_txn(
+                    txn, update_name, new_progress.dict()
+                )
+
+            await self.db_pool.runInteraction(
+                "validate_constraint_in_background_check",
+                validate_constraint_in_background_check,
+            )
+
+            return batch_size
+
+        elif parsed_progress.state == ValidateConstraintProgress.State.validate:
+            sql = f"ALTER TABLE {table} VALIDATE CONSTRAINT {constraint_name}"
+
+            def validate_constraint_in_background_validate(
+                txn: "LoggingTransaction",
+            ) -> None:
+                txn.execute(sql)
+
+            try:
+                await self.db_pool.runInteraction(
+                    "validate_constraint_in_background_validate",
+                    validate_constraint_in_background_validate,
+                )
+
+                await self._end_background_update(update_name)
+            except self.db_pool.engine.module.IntegrityError as e:
+                # If we get an integrity error here, then we go back and recheck the table.
+                logger.warning("Integrity error when validating constraint: %s", e)
+                await self._background_update_progress(
+                    update_name,
+                    ValidateConstraintProgress(
+                        state=ValidateConstraintProgress.State.check
+                    ).dict(),
+                )
+
+            return batch_size
+        else:
+            raise Exception(
+                f"Unrecognized state '{parsed_progress.state}' when trying to validate_constraint_and_delete_in_background"
+            )
+
     async def _end_background_update(self, update_name: str) -> None:
         """Removes a completed background update task from the queue.
 
@@ -795,3 +1045,86 @@ class BackgroundUpdater:
             keyvalues={"update_name": update_name},
             updatevalues={"progress_json": progress_json},
         )
+
+
+def run_validate_constraint_and_delete_rows_schema_delta(
+    txn: "LoggingTransaction",
+    ordering: int,
+    update_name: str,
+    table: str,
+    constraint_name: str,
+    constraint: Constraint,
+    sqlite_table_name: str,
+    sqlite_table_schema: str,
+) -> None:
+    """Runs a schema delta to add a constraint to the table. This should be run
+    in a schema delta file.
+
+    For PostgreSQL the constraint is added and validated in the background.
+
+    For SQLite the table is recreated and data copied across immediately. This
+    is done by the caller passing in a script to create the new table. Note that
+    table indexes and triggers are copied over automatically.
+
+    There must be a corresponding call to
+    `register_background_validate_constraint_and_delete_rows` to register the
+    background update in one of the data store classes.
+
+    Attributes:
+        txn ordering, update_name: For adding a row to background_updates table.
+        table: The table to add constraint to. constraint_name: The name of the
+        new constraint constraint: A `Constraint` object describing the
+        constraint sqlite_table_name: For SQLite the name of the empty copy of
+        table sqlite_table_schema: A SQL script for creating the above table.
+    """
+
+    if isinstance(txn.database_engine, PostgresEngine):
+        # For postgres we can just add the constraint and mark it as NOT VALID,
+        # and then insert a background update to go and check the validity in
+        # the background.
+        txn.execute(
+            f"""
+            ALTER TABLE {table}
+            ADD CONSTRAINT {constraint_name} {constraint.make_constraint_clause_postgres()}
+            NOT VALID
+            """
+        )
+
+        txn.execute(
+            "INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (?, ?, '{}')",
+            (ordering, update_name),
+        )
+    else:
+        # For SQLite, we:
+        #   1. fetch all indexes/triggers/etc related to the table
+        #   2. create an empty copy of the table
+        #   3. copy across the rows (that satisfy the check)
+        #   4. replace the old table with the new able.
+        #   5. add back all the indexes/triggers/etc
+
+        # Fetch the indexes/triggers/etc. Note that `sql` column being null is
+        # due to indexes being auto created based on the class definition (e.g.
+        # PRIMARY KEY), and so don't need to be recreated.
+        txn.execute(
+            """
+            SELECT sql FROM sqlite_master
+            WHERE tbl_name = ? AND type != 'table' AND sql IS NOT NULL
+            """,
+            (table,),
+        )
+        extras = [row[0] for row in txn]
+
+        txn.execute(sqlite_table_schema)
+
+        sql = f"""
+            INSERT INTO {sqlite_table_name} SELECT * FROM {table}
+            WHERE {constraint.make_check_clause(table)}
+        """
+
+        txn.execute(sql)
+
+        txn.execute(f"DROP TABLE {table}")
+        txn.execute(f"ALTER TABLE {sqlite_table_name} RENAME TO {table}")
+
+        for extra in extras:
+            txn.execute(extra)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 7e49ae11bc..a1c8fb0f46 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -2313,6 +2313,43 @@ class DatabasePool:
 
         return txn.rowcount
 
+    @staticmethod
+    def simple_delete_many_batch_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keys: Collection[str],
+        values: Iterable[Iterable[Any]],
+    ) -> None:
+        """Executes a DELETE query on the named table.
+
+        The input is given as a list of rows, where each row is a list of values.
+        (Actually any iterable is fine.)
+
+        Args:
+            txn: The transaction to use.
+            table: string giving the table name
+            keys: list of column names
+            values: for each row, a list of values in the same order as `keys`
+        """
+
+        if isinstance(txn.database_engine, PostgresEngine):
+            # We use `execute_values` as it can be a lot faster than `execute_batch`,
+            # but it's only available on postgres.
+            sql = "DELETE FROM %s WHERE (%s) IN (VALUES ?)" % (
+                table,
+                ", ".join(k for k in keys),
+            )
+
+            txn.execute_values(sql, values, fetch=False)
+        else:
+            sql = "DELETE FROM %s WHERE (%s) = (%s)" % (
+                table,
+                ", ".join(k for k in keys),
+                ", ".join("?" for _ in keys),
+            )
+
+            txn.execute_batch(sql, values)
+
     def get_cache_dict(
         self,
         db_conn: LoggingDatabaseConnection,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 8b6e3c1dc7..dabe603c8c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -38,6 +38,7 @@ from synapse.events import EventBase, make_event_from_dict
 from synapse.logging.opentracing import tag_args, trace
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.background_updates import ForeignKeyConstraint
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
@@ -140,6 +141,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
 
+        if isinstance(self.database_engine, PostgresEngine):
+            self.db_pool.updates.register_background_validate_constraint_and_delete_rows(
+                update_name="event_forward_extremities_event_id_foreign_key_constraint_update",
+                table="event_forward_extremities",
+                constraint_name="event_forward_extremities_event_id",
+                constraint=ForeignKeyConstraint("events", [("event_id", "event_id")]),
+                unique_columns=("event_id", "room_id"),
+            )
+
     async def get_auth_chain(
         self, room_id: str, event_ids: Collection[str], include_given: bool = False
     ) -> List[EventBase]:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5c9db7554e..2b83a69426 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -415,12 +415,6 @@ class PersistEventsStore:
                 backfilled=False,
             )
 
-        self._update_forward_extremities_txn(
-            txn,
-            new_forward_extremities=new_forward_extremities,
-            max_stream_order=max_stream_order,
-        )
-
         # Ensure that we don't have the same event twice.
         events_and_contexts = self._filter_events_and_contexts_for_duplicates(
             events_and_contexts
@@ -439,6 +433,12 @@ class PersistEventsStore:
 
         self._store_event_txn(txn, events_and_contexts=events_and_contexts)
 
+        self._update_forward_extremities_txn(
+            txn,
+            new_forward_extremities=new_forward_extremities,
+            max_stream_order=max_stream_order,
+        )
+
         self._persist_transaction_ids_txn(txn, events_and_contexts)
 
         # Insert into event_to_state_groups.
diff --git a/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py b/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py
new file mode 100644
index 0000000000..f12e2a8f3e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py
@@ -0,0 +1,51 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This migration adds foreign key constraint to `event_forward_extremities` table.
+"""
+from synapse.storage.background_updates import (
+    ForeignKeyConstraint,
+    run_validate_constraint_and_delete_rows_schema_delta,
+)
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.engines import BaseDatabaseEngine
+
+FORWARD_EXTREMITIES_TABLE_SCHEMA = """
+    CREATE TABLE event_forward_extremities2(
+        event_id TEXT NOT NULL,
+        room_id TEXT NOT NULL,
+        UNIQUE (event_id, room_id),
+        CONSTRAINT event_forward_extremities_event_id FOREIGN KEY (event_id) REFERENCES events (event_id)
+    )
+"""
+
+
+def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None:
+    run_validate_constraint_and_delete_rows_schema_delta(
+        cur,
+        ordering=7803,
+        update_name="event_forward_extremities_event_id_foreign_key_constraint_update",
+        table="event_forward_extremities",
+        constraint_name="event_forward_extremities_event_id",
+        constraint=ForeignKeyConstraint("events", [("event_id", "event_id")]),
+        sqlite_table_name="event_forward_extremities2",
+        sqlite_table_schema=FORWARD_EXTREMITIES_TABLE_SCHEMA,
+    )
+
+    # We can't add a similar constraint to `event_backward_extremities` as the
+    # events in there don't exist in the `events` table and `event_edges`
+    # doesn't have a unique constraint on `prev_event_id` (so we can't make a
+    # foreign key point to it).