summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15751.misc1
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py2
-rw-r--r--synapse/storage/background_updates.py335
-rw-r--r--synapse/storage/database.py37
-rw-r--r--synapse/storage/databases/main/event_federation.py10
-rw-r--r--synapse/storage/databases/main/events.py12
-rw-r--r--synapse/storage/schema/main/delta/78/03event_extremities_constraints.py51
-rw-r--r--tests/storage/test_background_update.py227
-rw-r--r--tests/storage/test_event_federation.py35
9 files changed, 699 insertions, 11 deletions
diff --git a/changelog.d/15751.misc b/changelog.d/15751.misc
new file mode 100644
index 0000000000..e0ecea6c2f
--- /dev/null
+++ b/changelog.d/15751.misc
@@ -0,0 +1 @@
+Add foreign key constraint to `event_forward_extremities`.
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index a803ada8ad..e126a2e0c5 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -61,6 +61,7 @@ from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpda
 from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
 from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore
+from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
 from synapse.storage.databases.main.event_push_actions import EventPushActionsStore
 from synapse.storage.databases.main.events_bg_updates import (
     EventsBackgroundUpdatesStore,
@@ -239,6 +240,7 @@ class Store(
     PresenceBackgroundUpdateStore,
     ReceiptsBackgroundUpdateStore,
     RelationsWorkerStore,
+    EventFederationWorkerStore,
 ):
     def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
         return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index edc97a9d61..5dce0a0159 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import logging
-from enum import IntEnum
+from enum import Enum, IntEnum
 from types import TracebackType
 from typing import (
     TYPE_CHECKING,
@@ -24,12 +25,16 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sequence,
+    Tuple,
     Type,
 )
 
 import attr
+from pydantic import BaseModel
 
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Connection, Cursor
 from synapse.types import JsonDict
 from synapse.util import Clock, json_encoder
@@ -48,6 +53,78 @@ DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
 MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
 
 
+class Constraint(metaclass=abc.ABCMeta):
+    """Base class representing different constraints.
+
+    Used by `register_background_validate_constraint_and_delete_rows`.
+    """
+
+    @abc.abstractmethod
+    def make_check_clause(self, table: str) -> str:
+        """Returns an SQL expression that checks the row passes the constraint."""
+        pass
+
+    @abc.abstractmethod
+    def make_constraint_clause_postgres(self) -> str:
+        """Returns an SQL clause for creating the constraint.
+
+        Only used on Postgres DBs
+        """
+        pass
+
+
+@attr.s(auto_attribs=True)
+class ForeignKeyConstraint(Constraint):
+    """A foreign key constraint.
+
+    Attributes:
+        referenced_table: The "parent" table name.
+        columns: The list of mappings of columns from table to referenced table
+    """
+
+    referenced_table: str
+    columns: Sequence[Tuple[str, str]]
+
+    def make_check_clause(self, table: str) -> str:
+        join_clause = " AND ".join(
+            f"{col1} = {table}.{col2}" for col1, col2 in self.columns
+        )
+        return f"EXISTS (SELECT 1 FROM {self.referenced_table} WHERE {join_clause})"
+
+    def make_constraint_clause_postgres(self) -> str:
+        column1_list = ", ".join(col1 for col1, col2 in self.columns)
+        column2_list = ", ".join(col2 for col1, col2 in self.columns)
+        return f"FOREIGN KEY ({column1_list}) REFERENCES {self.referenced_table} ({column2_list})"
+
+
+@attr.s(auto_attribs=True)
+class NotNullConstraint(Constraint):
+    """A NOT NULL column constraint"""
+
+    column: str
+
+    def make_check_clause(self, table: str) -> str:
+        return f"{self.column} IS NOT NULL"
+
+    def make_constraint_clause_postgres(self) -> str:
+        return f"CHECK ({self.column} IS NOT NULL)"
+
+
+class ValidateConstraintProgress(BaseModel):
+    """The format of the progress JSON for validate constraint background
+    updates.
+
+    Used by `register_background_validate_constraint_and_delete_rows`.
+    """
+
+    class State(str, Enum):
+        check = "check"
+        validate = "validate"
+
+    state: State = State.validate
+    lower_bound: Sequence[Any] = ()
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _BackgroundUpdateHandler:
     """A handler for a given background update.
@@ -740,6 +817,179 @@ class BackgroundUpdater:
         logger.info("Adding index %s to %s", index_name, table)
         await self.db_pool.runWithConnection(runner)
 
+    def register_background_validate_constraint_and_delete_rows(
+        self,
+        update_name: str,
+        table: str,
+        constraint_name: str,
+        constraint: Constraint,
+        unique_columns: Sequence[str],
+    ) -> None:
+        """Helper for store classes to do a background validate constraint, and
+        delete rows that do not pass the constraint check.
+
+        Note: This deletes rows that don't match the constraint. This may not be
+        appropriate in all situations, and so the suitability of using this
+        method should be considered on a case-by-case basis.
+
+        This only applies on PostgreSQL.
+
+        For SQLite the table gets recreated as part of the schema delta and the
+        data is copied over synchronously (or whatever the correct way to
+        describe it as).
+
+        Args:
+            update_name: The name of the background update.
+            table: The table with the invalid constraint.
+            constraint_name: The name of the constraint
+            constraint: A `Constraint` object matching the type of constraint.
+            unique_columns: A sequence of columns that form a unique constraint
+              on the table. Used to iterate over the table.
+        """
+
+        assert isinstance(
+            self.db_pool.engine, engines.PostgresEngine
+        ), "validate constraint background update registered for non-Postres database"
+
+        async def updater(progress: JsonDict, batch_size: int) -> int:
+            return await self.validate_constraint_and_delete_in_background(
+                update_name=update_name,
+                table=table,
+                constraint_name=constraint_name,
+                constraint=constraint,
+                unique_columns=unique_columns,
+                progress=progress,
+                batch_size=batch_size,
+            )
+
+        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+            updater, oneshot=True
+        )
+
+    async def validate_constraint_and_delete_in_background(
+        self,
+        update_name: str,
+        table: str,
+        constraint_name: str,
+        constraint: Constraint,
+        unique_columns: Sequence[str],
+        progress: JsonDict,
+        batch_size: int,
+    ) -> int:
+        """Validates a table constraint that has been marked as `NOT VALID`,
+        deleting rows that don't pass the constraint check.
+
+        This will delete rows that do not meet the validation check.
+
+        update_name: str,
+        table: str,
+        constraint_name: str,
+        constraint: Constraint,
+        unique_columns: Sequence[str],
+        """
+
+        # We validate the constraint by:
+        #   1. Trying to validate the constraint as is. If this succeeds then
+        #      we're done.
+        #   2. Otherwise, we manually scan the table to remove rows that don't
+        #      match the constraint.
+        #   3. We try re-validating the constraint.
+
+        parsed_progress = ValidateConstraintProgress.parse_obj(progress)
+
+        if parsed_progress.state == ValidateConstraintProgress.State.check:
+            return_columns = ", ".join(unique_columns)
+            order_columns = ", ".join(unique_columns)
+
+            where_clause = ""
+            args: List[Any] = []
+            if parsed_progress.lower_bound:
+                where_clause = f"""WHERE ({order_columns}) > ({", ".join("?" for _ in unique_columns)})"""
+                args.extend(parsed_progress.lower_bound)
+
+            args.append(batch_size)
+
+            sql = f"""
+                SELECT
+                    {return_columns},
+                    {constraint.make_check_clause(table)} AS check
+                FROM {table}
+                {where_clause}
+                ORDER BY {order_columns}
+                LIMIT ?
+            """
+
+            def validate_constraint_in_background_check(
+                txn: "LoggingTransaction",
+            ) -> None:
+                txn.execute(sql, args)
+                rows = txn.fetchall()
+
+                new_progress = parsed_progress.copy()
+
+                if not rows:
+                    new_progress.state = ValidateConstraintProgress.State.validate
+                    self._background_update_progress_txn(
+                        txn, update_name, new_progress.dict()
+                    )
+                    return
+
+                new_progress.lower_bound = rows[-1][:-1]
+
+                to_delete = [row[:-1] for row in rows if not row[-1]]
+
+                if to_delete:
+                    logger.warning(
+                        "Deleting %d rows that do not pass new constraint",
+                        len(to_delete),
+                    )
+
+                    self.db_pool.simple_delete_many_batch_txn(
+                        txn, table=table, keys=unique_columns, values=to_delete
+                    )
+
+                self._background_update_progress_txn(
+                    txn, update_name, new_progress.dict()
+                )
+
+            await self.db_pool.runInteraction(
+                "validate_constraint_in_background_check",
+                validate_constraint_in_background_check,
+            )
+
+            return batch_size
+
+        elif parsed_progress.state == ValidateConstraintProgress.State.validate:
+            sql = f"ALTER TABLE {table} VALIDATE CONSTRAINT {constraint_name}"
+
+            def validate_constraint_in_background_validate(
+                txn: "LoggingTransaction",
+            ) -> None:
+                txn.execute(sql)
+
+            try:
+                await self.db_pool.runInteraction(
+                    "validate_constraint_in_background_validate",
+                    validate_constraint_in_background_validate,
+                )
+
+                await self._end_background_update(update_name)
+            except self.db_pool.engine.module.IntegrityError as e:
+                # If we get an integrity error here, then we go back and recheck the table.
+                logger.warning("Integrity error when validating constraint: %s", e)
+                await self._background_update_progress(
+                    update_name,
+                    ValidateConstraintProgress(
+                        state=ValidateConstraintProgress.State.check
+                    ).dict(),
+                )
+
+            return batch_size
+        else:
+            raise Exception(
+                f"Unrecognized state '{parsed_progress.state}' when trying to validate_constraint_and_delete_in_background"
+            )
+
     async def _end_background_update(self, update_name: str) -> None:
         """Removes a completed background update task from the queue.
 
@@ -795,3 +1045,86 @@ class BackgroundUpdater:
             keyvalues={"update_name": update_name},
             updatevalues={"progress_json": progress_json},
         )
+
+
+def run_validate_constraint_and_delete_rows_schema_delta(
+    txn: "LoggingTransaction",
+    ordering: int,
+    update_name: str,
+    table: str,
+    constraint_name: str,
+    constraint: Constraint,
+    sqlite_table_name: str,
+    sqlite_table_schema: str,
+) -> None:
+    """Runs a schema delta to add a constraint to the table. This should be run
+    in a schema delta file.
+
+    For PostgreSQL the constraint is added and validated in the background.
+
+    For SQLite the table is recreated and data copied across immediately. This
+    is done by the caller passing in a script to create the new table. Note that
+    table indexes and triggers are copied over automatically.
+
+    There must be a corresponding call to
+    `register_background_validate_constraint_and_delete_rows` to register the
+    background update in one of the data store classes.
+
+    Attributes:
+        txn ordering, update_name: For adding a row to background_updates table.
+        table: The table to add constraint to. constraint_name: The name of the
+        new constraint constraint: A `Constraint` object describing the
+        constraint sqlite_table_name: For SQLite the name of the empty copy of
+        table sqlite_table_schema: A SQL script for creating the above table.
+    """
+
+    if isinstance(txn.database_engine, PostgresEngine):
+        # For postgres we can just add the constraint and mark it as NOT VALID,
+        # and then insert a background update to go and check the validity in
+        # the background.
+        txn.execute(
+            f"""
+            ALTER TABLE {table}
+            ADD CONSTRAINT {constraint_name} {constraint.make_constraint_clause_postgres()}
+            NOT VALID
+            """
+        )
+
+        txn.execute(
+            "INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (?, ?, '{}')",
+            (ordering, update_name),
+        )
+    else:
+        # For SQLite, we:
+        #   1. fetch all indexes/triggers/etc related to the table
+        #   2. create an empty copy of the table
+        #   3. copy across the rows (that satisfy the check)
+        #   4. replace the old table with the new able.
+        #   5. add back all the indexes/triggers/etc
+
+        # Fetch the indexes/triggers/etc. Note that `sql` column being null is
+        # due to indexes being auto created based on the class definition (e.g.
+        # PRIMARY KEY), and so don't need to be recreated.
+        txn.execute(
+            """
+            SELECT sql FROM sqlite_master
+            WHERE tbl_name = ? AND type != 'table' AND sql IS NOT NULL
+            """,
+            (table,),
+        )
+        extras = [row[0] for row in txn]
+
+        txn.execute(sqlite_table_schema)
+
+        sql = f"""
+            INSERT INTO {sqlite_table_name} SELECT * FROM {table}
+            WHERE {constraint.make_check_clause(table)}
+        """
+
+        txn.execute(sql)
+
+        txn.execute(f"DROP TABLE {table}")
+        txn.execute(f"ALTER TABLE {sqlite_table_name} RENAME TO {table}")
+
+        for extra in extras:
+            txn.execute(extra)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 7e49ae11bc..a1c8fb0f46 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -2313,6 +2313,43 @@ class DatabasePool:
 
         return txn.rowcount
 
+    @staticmethod
+    def simple_delete_many_batch_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keys: Collection[str],
+        values: Iterable[Iterable[Any]],
+    ) -> None:
+        """Executes a DELETE query on the named table.
+
+        The input is given as a list of rows, where each row is a list of values.
+        (Actually any iterable is fine.)
+
+        Args:
+            txn: The transaction to use.
+            table: string giving the table name
+            keys: list of column names
+            values: for each row, a list of values in the same order as `keys`
+        """
+
+        if isinstance(txn.database_engine, PostgresEngine):
+            # We use `execute_values` as it can be a lot faster than `execute_batch`,
+            # but it's only available on postgres.
+            sql = "DELETE FROM %s WHERE (%s) IN (VALUES ?)" % (
+                table,
+                ", ".join(k for k in keys),
+            )
+
+            txn.execute_values(sql, values, fetch=False)
+        else:
+            sql = "DELETE FROM %s WHERE (%s) = (%s)" % (
+                table,
+                ", ".join(k for k in keys),
+                ", ".join("?" for _ in keys),
+            )
+
+            txn.execute_batch(sql, values)
+
     def get_cache_dict(
         self,
         db_conn: LoggingDatabaseConnection,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 8b6e3c1dc7..dabe603c8c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -38,6 +38,7 @@ from synapse.events import EventBase, make_event_from_dict
 from synapse.logging.opentracing import tag_args, trace
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.background_updates import ForeignKeyConstraint
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
@@ -140,6 +141,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
 
+        if isinstance(self.database_engine, PostgresEngine):
+            self.db_pool.updates.register_background_validate_constraint_and_delete_rows(
+                update_name="event_forward_extremities_event_id_foreign_key_constraint_update",
+                table="event_forward_extremities",
+                constraint_name="event_forward_extremities_event_id",
+                constraint=ForeignKeyConstraint("events", [("event_id", "event_id")]),
+                unique_columns=("event_id", "room_id"),
+            )
+
     async def get_auth_chain(
         self, room_id: str, event_ids: Collection[str], include_given: bool = False
     ) -> List[EventBase]:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5c9db7554e..2b83a69426 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -415,12 +415,6 @@ class PersistEventsStore:
                 backfilled=False,
             )
 
-        self._update_forward_extremities_txn(
-            txn,
-            new_forward_extremities=new_forward_extremities,
-            max_stream_order=max_stream_order,
-        )
-
         # Ensure that we don't have the same event twice.
         events_and_contexts = self._filter_events_and_contexts_for_duplicates(
             events_and_contexts
@@ -439,6 +433,12 @@ class PersistEventsStore:
 
         self._store_event_txn(txn, events_and_contexts=events_and_contexts)
 
+        self._update_forward_extremities_txn(
+            txn,
+            new_forward_extremities=new_forward_extremities,
+            max_stream_order=max_stream_order,
+        )
+
         self._persist_transaction_ids_txn(txn, events_and_contexts)
 
         # Insert into event_to_state_groups.
diff --git a/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py b/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py
new file mode 100644
index 0000000000..f12e2a8f3e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py
@@ -0,0 +1,51 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This migration adds foreign key constraint to `event_forward_extremities` table.
+"""
+from synapse.storage.background_updates import (
+    ForeignKeyConstraint,
+    run_validate_constraint_and_delete_rows_schema_delta,
+)
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.engines import BaseDatabaseEngine
+
+FORWARD_EXTREMITIES_TABLE_SCHEMA = """
+    CREATE TABLE event_forward_extremities2(
+        event_id TEXT NOT NULL,
+        room_id TEXT NOT NULL,
+        UNIQUE (event_id, room_id),
+        CONSTRAINT event_forward_extremities_event_id FOREIGN KEY (event_id) REFERENCES events (event_id)
+    )
+"""
+
+
+def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None:
+    run_validate_constraint_and_delete_rows_schema_delta(
+        cur,
+        ordering=7803,
+        update_name="event_forward_extremities_event_id_foreign_key_constraint_update",
+        table="event_forward_extremities",
+        constraint_name="event_forward_extremities_event_id",
+        constraint=ForeignKeyConstraint("events", [("event_id", "event_id")]),
+        sqlite_table_name="event_forward_extremities2",
+        sqlite_table_schema=FORWARD_EXTREMITIES_TABLE_SCHEMA,
+    )
+
+    # We can't add a similar constraint to `event_backward_extremities` as the
+    # events in there don't exist in the `events` table and `event_edges`
+    # doesn't have a unique constraint on `prev_event_id` (so we can't make a
+    # foreign key point to it).
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index fd619b64d4..6ca546f3f7 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -20,7 +20,14 @@ from twisted.internet.defer import Deferred, ensureDeferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.server import HomeServer
-from synapse.storage.background_updates import BackgroundUpdater
+from synapse.storage.background_updates import (
+    BackgroundUpdater,
+    ForeignKeyConstraint,
+    NotNullConstraint,
+    run_validate_constraint_and_delete_rows_schema_delta,
+)
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.types import JsonDict
 from synapse.util import Clock
 
@@ -404,3 +411,221 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
         self.pump()
         self._update_ctx_manager.__aexit__.assert_called()
         self.get_success(do_update_d)
+
+
+class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
+    """Tests the validate contraint and delete background handlers."""
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
+        # the base test class should have run the real bg updates for us
+        self.assertTrue(
+            self.get_success(self.updates.has_completed_background_updates())
+        )
+
+        self.store = self.hs.get_datastores().main
+
+    def test_not_null_constraint(self) -> None:
+        # Create the initial tables, where we have some invalid data.
+        """Tests adding a not null constraint."""
+        table_sql = """
+            CREATE TABLE test_constraint(
+                a INT PRIMARY KEY,
+                b INT
+            );
+        """
+        self.get_success(
+            self.store.db_pool.execute(
+                "test_not_null_constraint", lambda _: None, table_sql
+            )
+        )
+
+        # We add an index so that we can check that its correctly recreated when
+        # using SQLite.
+        index_sql = "CREATE INDEX test_index ON test_constraint(a)"
+        self.get_success(
+            self.store.db_pool.execute(
+                "test_not_null_constraint", lambda _: None, index_sql
+            )
+        )
+
+        self.get_success(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 1, "b": 1})
+        )
+        self.get_success(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": None})
+        )
+        self.get_success(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 3, "b": 3})
+        )
+
+        # Now lets do the migration
+
+        table2_sqlite = """
+            CREATE TABLE test_constraint2(
+                a INT PRIMARY KEY,
+                b INT,
+                CONSTRAINT test_constraint_name CHECK (b is NOT NULL)
+            );
+        """
+
+        def delta(txn: LoggingTransaction) -> None:
+            run_validate_constraint_and_delete_rows_schema_delta(
+                txn,
+                ordering=1000,
+                update_name="test_bg_update",
+                table="test_constraint",
+                constraint_name="test_constraint_name",
+                constraint=NotNullConstraint("b"),
+                sqlite_table_name="test_constraint2",
+                sqlite_table_schema=table2_sqlite,
+            )
+
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test_not_null_constraint",
+                delta,
+            )
+        )
+
+        if isinstance(self.store.database_engine, PostgresEngine):
+            # Postgres uses a background update
+            self.updates.register_background_validate_constraint_and_delete_rows(
+                "test_bg_update",
+                table="test_constraint",
+                constraint_name="test_constraint_name",
+                constraint=NotNullConstraint("b"),
+                unique_columns=["a"],
+            )
+
+            # Tell the DataStore that it hasn't finished all updates yet
+            self.store.db_pool.updates._all_done = False
+
+            # Now let's actually drive the updates to completion
+            self.wait_for_background_updates()
+
+        # Check the correct values are in the new table.
+        rows = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="test_constraint",
+                keyvalues={},
+                retcols=("a", "b"),
+            )
+        )
+
+        self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
+
+        # And check that invalid rows get correctly rejected.
+        self.get_failure(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": None}),
+            exc=self.store.database_engine.module.IntegrityError,
+        )
+
+        # Check the index is still there for SQLite.
+        if isinstance(self.store.database_engine, Sqlite3Engine):
+            # Ensure the index exists in the schema.
+            self.get_success(
+                self.store.db_pool.simple_select_one_onecol(
+                    table="sqlite_master",
+                    keyvalues={"tbl_name": "test_constraint"},
+                    retcol="name",
+                )
+            )
+
+    def test_foreign_constraint(self) -> None:
+        """Tests adding a not foreign key constraint."""
+
+        # Create the initial tables, where we have some invalid data.
+        base_sql = """
+            CREATE TABLE base_table(
+                b INT PRIMARY KEY
+            );
+        """
+
+        table_sql = """
+            CREATE TABLE test_constraint(
+                a INT PRIMARY KEY,
+                b INT NOT NULL
+            );
+        """
+        self.get_success(
+            self.store.db_pool.execute(
+                "test_foreign_key_constraint", lambda _: None, base_sql
+            )
+        )
+        self.get_success(
+            self.store.db_pool.execute(
+                "test_foreign_key_constraint", lambda _: None, table_sql
+            )
+        )
+
+        self.get_success(self.store.db_pool.simple_insert("base_table", {"b": 1}))
+        self.get_success(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 1, "b": 1})
+        )
+        self.get_success(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": 2})
+        )
+        self.get_success(self.store.db_pool.simple_insert("base_table", {"b": 3}))
+        self.get_success(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 3, "b": 3})
+        )
+
+        table2_sqlite = """
+            CREATE TABLE test_constraint2(
+                a INT PRIMARY KEY,
+                b INT NOT NULL,
+                CONSTRAINT test_constraint_name FOREIGN KEY (b) REFERENCES base_table (b)
+            );
+        """
+
+        def delta(txn: LoggingTransaction) -> None:
+            run_validate_constraint_and_delete_rows_schema_delta(
+                txn,
+                ordering=1000,
+                update_name="test_bg_update",
+                table="test_constraint",
+                constraint_name="test_constraint_name",
+                constraint=ForeignKeyConstraint("base_table", [("b", "b")]),
+                sqlite_table_name="test_constraint2",
+                sqlite_table_schema=table2_sqlite,
+            )
+
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test_foreign_key_constraint",
+                delta,
+            )
+        )
+
+        if isinstance(self.store.database_engine, PostgresEngine):
+            # Postgres uses a background update
+            self.updates.register_background_validate_constraint_and_delete_rows(
+                "test_bg_update",
+                table="test_constraint",
+                constraint_name="test_constraint_name",
+                constraint=ForeignKeyConstraint("base_table", [("b", "b")]),
+                unique_columns=["a"],
+            )
+
+            # Tell the DataStore that it hasn't finished all updates yet
+            self.store.db_pool.updates._all_done = False
+
+            # Now let's actually drive the updates to completion
+            self.wait_for_background_updates()
+
+        # Check the correct values are in the new table.
+        rows = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="test_constraint",
+                keyvalues={},
+                retcols=("a", "b"),
+            )
+        )
+        self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
+
+        # And check that invalid rows get correctly rejected.
+        self.get_failure(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": 2}),
+            exc=self.store.database_engine.module.IntegrityError,
+        )
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 0f3b0744f1..9c151a5e62 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -20,6 +20,7 @@ from parameterized import parameterized
 
 from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.api.constants import EventTypes
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
     EventFormatVersions,
@@ -98,8 +99,32 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         room2 = "#room2"
         room3 = "#room3"
 
-        def insert_event(txn: Cursor, i: int, room_id: str) -> None:
+        def insert_event(txn: LoggingTransaction, i: int, room_id: str) -> None:
             event_id = "$event_%i:local" % i
+
+            # We need to insert into events table to get around the foreign key constraint.
+            self.store.db_pool.simple_insert_txn(
+                txn,
+                table="events",
+                values={
+                    "instance_name": "master",
+                    "stream_ordering": self.store._stream_id_gen.get_next_txn(txn),
+                    "topological_ordering": 1,
+                    "depth": 1,
+                    "event_id": event_id,
+                    "room_id": room_id,
+                    "type": EventTypes.Message,
+                    "processed": True,
+                    "outlier": False,
+                    "origin_server_ts": 0,
+                    "received_ts": 0,
+                    "sender": "@user:local",
+                    "contains_url": False,
+                    "state_key": None,
+                    "rejection_reason": None,
+                },
+            )
+
             txn.execute(
                 (
                     "INSERT INTO event_forward_extremities (room_id, event_id) "
@@ -113,10 +138,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 self.store.db_pool.runInteraction("insert", insert_event, i, room1)
             )
             self.get_success(
-                self.store.db_pool.runInteraction("insert", insert_event, i, room2)
+                self.store.db_pool.runInteraction(
+                    "insert", insert_event, i + 100, room2
+                )
             )
             self.get_success(
-                self.store.db_pool.runInteraction("insert", insert_event, i, room3)
+                self.store.db_pool.runInteraction(
+                    "insert", insert_event, i + 200, room3
+                )
             )
 
         # Test simple case