summary refs log tree commit diff
path: root/synapse/storage/background_updates.py
diff options
context:
space:
mode:
author Erik Johnston <erik@matrix.org> 2023-07-05 10:43:19 +0100
committer GitHub <noreply@github.com> 2023-07-05 09:43:19 +0000
commit 95a96b21eb98c638ae36814ec74ba468226e373c (patch)
tree 96dd63eb54af01b3ced5a3611fa4600f4788d202 /synapse/storage/background_updates.py
parent use Image.LANCZOS instead of Image.ANTIALIAS for thumbnail resize (#15876) (diff)
download synapse-95a96b21eb98c638ae36814ec74ba468226e373c.tar.xz
Add foreign key constraint to `event_forward_extremities`. (#15751)
Diffstat (limited to 'synapse/storage/background_updates.py')
-rw-r--r-- synapse/storage/background_updates.py 335
1 files changed, 334 insertions, 1 deletions
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index edc97a9d61..5dce0a0159 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import logging
-from enum import IntEnum
+from enum import Enum, IntEnum
 from types import TracebackType
 from typing import (
     TYPE_CHECKING,
@@ -24,12 +25,16 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sequence,
+    Tuple,
     Type,
 )
 
 import attr
+from pydantic import BaseModel
 
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Connection, Cursor
 from synapse.types import JsonDict
 from synapse.util import Clock, json_encoder
@@ -48,6 +53,78 @@ DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
 MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
 
 
+class Constraint(metaclass=abc.ABCMeta):
+    """Base class for table constraints that can be expressed as SQL clauses.
+
+    Used by `register_background_validate_constraint_and_delete_rows`.
+    """
+
+    @abc.abstractmethod
+    def make_check_clause(self, table: str) -> str:
+        """Returns an SQL boolean expression that is true for rows of `table` that pass."""
+        pass
+
+    @abc.abstractmethod
+    def make_constraint_clause_postgres(self) -> str:
+        """Returns an SQL clause for creating the constraint.
+
+        Only used on Postgres DBs (appended to `ALTER TABLE ... ADD CONSTRAINT <name>`).
+        """
+        pass
+
+
+@attr.s(auto_attribs=True)
+class ForeignKeyConstraint(Constraint):
+    """A foreign key constraint.
+
+    Attributes:
+        referenced_table: The "parent" table name.
+        columns: Pairs of `(column in table, column in referenced table)`.
+    """
+
+    referenced_table: str
+    columns: Sequence[Tuple[str, str]]
+
+    def make_check_clause(self, table: str) -> str:
+        # Qualify both sides so pairs with differing column names resolve correctly.
+        ref = self.referenced_table
+        pairs = (f"{ref}.{col2} = {table}.{col1}" for col1, col2 in self.columns)
+        return f"EXISTS (SELECT 1 FROM {ref} WHERE {' AND '.join(pairs)})"
+
+    def make_constraint_clause_postgres(self) -> str:
+        column1_list = ", ".join(col1 for col1, _ in self.columns)
+        column2_list = ", ".join(col2 for _, col2 in self.columns)
+        return f"FOREIGN KEY ({column1_list}) REFERENCES {self.referenced_table} ({column2_list})"
+
+
+@attr.s(auto_attribs=True)
+class NotNullConstraint(Constraint):
+    """A NOT NULL constraint on a single column."""
+
+    column: str  # name of the column that must not be NULL
+
+    def make_check_clause(self, table: str) -> str:
+        return f"{self.column} IS NOT NULL"  # `table` is unused: the column is unqualified
+
+    def make_constraint_clause_postgres(self) -> str:
+        return f"CHECK ({self.column} IS NOT NULL)"  # emitted as a CHECK clause
+
+
+class ValidateConstraintProgress(BaseModel):
+    """The format of the progress JSON for validate constraint background
+    updates.
+
+    Used by `register_background_validate_constraint_and_delete_rows`.
+    """
+
+    class State(str, Enum):
+        check = "check"  # scanning the table and deleting rows that fail the check
+        validate = "validate"  # attempting `ALTER TABLE ... VALIDATE CONSTRAINT`
+
+    state: State = State.validate  # empty progress (`{}`) starts by trying to validate
+    lower_bound: Sequence[Any] = ()  # unique-column values of the last row scanned
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _BackgroundUpdateHandler:
     """A handler for a given background update.
@@ -740,6 +817,179 @@ class BackgroundUpdater:
         logger.info("Adding index %s to %s", index_name, table)
         await self.db_pool.runWithConnection(runner)
 
+    def register_background_validate_constraint_and_delete_rows(
+        self,
+        update_name: str,
+        table: str,
+        constraint_name: str,
+        constraint: Constraint,
+        unique_columns: Sequence[str],
+    ) -> None:
+        """Helper for store classes to do a background validate constraint, and
+        delete rows that do not pass the constraint check.
+
+        Note: This deletes rows that don't match the constraint. This may not be
+        appropriate in all situations, and so the suitability of using this
+        method should be considered on a case-by-case basis.
+
+        This only applies on PostgreSQL; registering on another engine asserts.
+
+        For SQLite the table is instead recreated by the schema delta, which
+        copies the rows that pass the constraint across synchronously (see
+        `run_validate_constraint_and_delete_rows_schema_delta`).
+
+        Args:
+            update_name: The name of the background update.
+            table: The table with the invalid constraint.
+            constraint_name: The name of the constraint.
+            constraint: A `Constraint` object matching the type of constraint.
+            unique_columns: A sequence of columns that form a unique constraint
+              on the table. Used to iterate over the table.
+        """
+
+        assert isinstance(
+            self.db_pool.engine, PostgresEngine
+        ), "validate constraint background update registered for non-Postgres database"
+
+        async def updater(progress: JsonDict, batch_size: int) -> int:
+            return await self.validate_constraint_and_delete_in_background(
+                update_name=update_name,
+                table=table,
+                constraint_name=constraint_name,
+                constraint=constraint,
+                unique_columns=unique_columns,
+                progress=progress,
+                batch_size=batch_size,
+            )
+
+        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+            updater, oneshot=True
+        )
+
+    async def validate_constraint_and_delete_in_background(
+        self,
+        update_name: str,
+        table: str,
+        constraint_name: str,
+        constraint: Constraint,
+        unique_columns: Sequence[str],
+        progress: JsonDict,
+        batch_size: int,
+    ) -> int:
+        """Validates a table constraint that has been marked as `NOT VALID`,
+        deleting rows that don't pass the constraint check.
+
+        Args:
+            update_name: The name of the background update.
+            table: The table carrying the `NOT VALID` constraint.
+            constraint_name: The name of the constraint to validate.
+            constraint: A `Constraint` object matching the type of constraint.
+            unique_columns: Columns forming a unique constraint; used to iterate.
+            progress, batch_size: Stored progress JSON, and max rows scanned per call.
+        """
+
+        # We validate the constraint by:
+        #   1. Trying to validate the constraint as is. If this succeeds then
+        #      we're done.
+        #   2. Otherwise, we manually scan the table to remove rows that don't
+        #      match the constraint.
+        #   3. We try re-validating the constraint.
+
+        parsed_progress = ValidateConstraintProgress.parse_obj(progress)
+
+        if parsed_progress.state == ValidateConstraintProgress.State.check:
+            return_columns = ", ".join(unique_columns)
+            order_columns = ", ".join(unique_columns)
+
+            where_clause = ""
+            args: List[Any] = []
+            if parsed_progress.lower_bound:  # resume after the last row of the previous batch
+                where_clause = f"""WHERE ({order_columns}) > ({", ".join("?" for _ in unique_columns)})"""
+                args.extend(parsed_progress.lower_bound)
+
+            args.append(batch_size)
+
+            sql = f"""
+                SELECT
+                    {return_columns},
+                    {constraint.make_check_clause(table)} AS check
+                FROM {table}
+                {where_clause}
+                ORDER BY {order_columns}
+                LIMIT ?
+            """
+
+            def validate_constraint_in_background_check(
+                txn: "LoggingTransaction",
+            ) -> None:
+                txn.execute(sql, args)
+                rows = txn.fetchall()
+
+                new_progress = parsed_progress.copy()
+
+                if not rows:  # scanned past the end of the table: retry the validation
+                    new_progress.state = ValidateConstraintProgress.State.validate
+                    self._background_update_progress_txn(
+                        txn, update_name, new_progress.dict()
+                    )
+                    return
+
+                new_progress.lower_bound = rows[-1][:-1]  # drop the trailing `check` column
+
+                to_delete = [row[:-1] for row in rows if not row[-1]]  # rows failing the check
+
+                if to_delete:
+                    logger.warning(
+                        "Deleting %d rows that do not pass new constraint",
+                        len(to_delete),
+                    )
+
+                    self.db_pool.simple_delete_many_batch_txn(
+                        txn, table=table, keys=unique_columns, values=to_delete
+                    )
+
+                self._background_update_progress_txn(
+                    txn, update_name, new_progress.dict()
+                )
+
+            await self.db_pool.runInteraction(
+                "validate_constraint_in_background_check",
+                validate_constraint_in_background_check,
+            )
+
+            return batch_size
+
+        elif parsed_progress.state == ValidateConstraintProgress.State.validate:
+            sql = f"ALTER TABLE {table} VALIDATE CONSTRAINT {constraint_name}"
+
+            def validate_constraint_in_background_validate(
+                txn: "LoggingTransaction",
+            ) -> None:
+                txn.execute(sql)
+
+            try:
+                await self.db_pool.runInteraction(
+                    "validate_constraint_in_background_validate",
+                    validate_constraint_in_background_validate,
+                )
+
+                await self._end_background_update(update_name)  # validated: we're done
+            except self.db_pool.engine.module.IntegrityError as e:
+                # If we get an integrity error here, then we go back and recheck the table.
+                logger.warning("Integrity error when validating constraint: %s", e)
+                await self._background_update_progress(
+                    update_name,
+                    ValidateConstraintProgress(
+                        state=ValidateConstraintProgress.State.check
+                    ).dict(),
+                )
+
+            return batch_size
+        else:
+            raise Exception(
+                f"Unrecognized state '{parsed_progress.state}' when trying to validate_constraint_and_delete_in_background"
+            )
+
     async def _end_background_update(self, update_name: str) -> None:
         """Removes a completed background update task from the queue.
 
@@ -795,3 +1045,86 @@ class BackgroundUpdater:
             keyvalues={"update_name": update_name},
             updatevalues={"progress_json": progress_json},
         )
+
+
+def run_validate_constraint_and_delete_rows_schema_delta(
+    txn: "LoggingTransaction",
+    ordering: int,
+    update_name: str,
+    table: str,
+    constraint_name: str,
+    constraint: Constraint,
+    sqlite_table_name: str,
+    sqlite_table_schema: str,
+) -> None:
+    """Runs a schema delta to add a constraint to the table. This should be run
+    in a schema delta file.
+
+    For PostgreSQL the constraint is added and validated in the background.
+
+    For SQLite the table is recreated and data copied across immediately. This
+    is done by the caller passing in a script to create the new table. Note that
+    table indexes and triggers are copied over automatically.
+
+    There must be a corresponding call to
+    `register_background_validate_constraint_and_delete_rows` to register the
+    background update in one of the data store classes.
+
+    Args:
+        txn, ordering, update_name: For adding a row to background_updates table.
+        table: The table to add constraint to.
+        constraint_name, constraint: The new constraint's name and `Constraint`.
+        sqlite_table_name: For SQLite the name of the empty copy of table.
+        sqlite_table_schema: A SQL script for creating the above table.
+    """
+
+    if isinstance(txn.database_engine, PostgresEngine):
+        # For postgres we can just add the constraint and mark it as NOT VALID,
+        # and then insert a background update to go and check the validity in
+        # the background.
+        txn.execute(
+            f"""
+            ALTER TABLE {table}
+            ADD CONSTRAINT {constraint_name} {constraint.make_constraint_clause_postgres()}
+            NOT VALID
+            """
+        )
+
+        txn.execute(
+            "INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (?, ?, '{}')",
+            (ordering, update_name),
+        )
+    else:
+        # For SQLite, we:
+        #   1. fetch all indexes/triggers/etc related to the table
+        #   2. create an empty copy of the table
+        #   3. copy across the rows (that satisfy the check)
+        #   4. replace the old table with the new table.
+        #   5. add back all the indexes/triggers/etc
+
+        # Fetch the indexes/triggers/etc. Note that `sql` column being null is
+        # due to indexes being auto created based on the table definition (e.g.
+        # PRIMARY KEY), and so don't need to be recreated.
+        txn.execute(
+            """
+            SELECT sql FROM sqlite_master
+            WHERE tbl_name = ? AND type != 'table' AND sql IS NOT NULL
+            """,
+            (table,),
+        )
+        extras = [row[0] for row in txn]
+
+        txn.execute(sqlite_table_schema)  # expected to create the empty `sqlite_table_name`
+
+        sql = f"""
+            INSERT INTO {sqlite_table_name} SELECT * FROM {table}
+            WHERE {constraint.make_check_clause(table)}
+        """
+
+        txn.execute(sql)
+
+        txn.execute(f"DROP TABLE {table}")
+        txn.execute(f"ALTER TABLE {sqlite_table_name} RENAME TO {table}")
+
+        for extra in extras:  # re-run the saved CREATE INDEX/TRIGGER/etc statements
+            txn.execute(extra)