diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index fd619b64d4..6ca546f3f7 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -20,7 +20,14 @@ from twisted.internet.defer import Deferred, ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
-from synapse.storage.background_updates import BackgroundUpdater
+from synapse.storage.background_updates import (
+ BackgroundUpdater,
+ ForeignKeyConstraint,
+ NotNullConstraint,
+ run_validate_constraint_and_delete_rows_schema_delta,
+)
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
from synapse.util import Clock
@@ -404,3 +411,221 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
self.pump()
self._update_ctx_manager.__aexit__.assert_called()
self.get_success(do_update_d)
+
+
+class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
+ """Tests the validate contraint and delete background handlers."""
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
+ # the base test class should have run the real bg updates for us
+ self.assertTrue(
+ self.get_success(self.updates.has_completed_background_updates())
+ )
+
+ self.store = self.hs.get_datastores().main
+
+ def test_not_null_constraint(self) -> None:
+ # Create the initial tables, where we have some invalid data.
+ """Tests adding a not null constraint."""
+ table_sql = """
+ CREATE TABLE test_constraint(
+ a INT PRIMARY KEY,
+ b INT
+ );
+ """
+ self.get_success(
+ self.store.db_pool.execute(
+ "test_not_null_constraint", lambda _: None, table_sql
+ )
+ )
+
+ # We add an index so that we can check that its correctly recreated when
+ # using SQLite.
+ index_sql = "CREATE INDEX test_index ON test_constraint(a)"
+ self.get_success(
+ self.store.db_pool.execute(
+ "test_not_null_constraint", lambda _: None, index_sql
+ )
+ )
+
+ self.get_success(
+ self.store.db_pool.simple_insert("test_constraint", {"a": 1, "b": 1})
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": None})
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert("test_constraint", {"a": 3, "b": 3})
+ )
+
+ # Now lets do the migration
+
+ table2_sqlite = """
+ CREATE TABLE test_constraint2(
+ a INT PRIMARY KEY,
+ b INT,
+ CONSTRAINT test_constraint_name CHECK (b is NOT NULL)
+ );
+ """
+
+ def delta(txn: LoggingTransaction) -> None:
+ run_validate_constraint_and_delete_rows_schema_delta(
+ txn,
+ ordering=1000,
+ update_name="test_bg_update",
+ table="test_constraint",
+ constraint_name="test_constraint_name",
+ constraint=NotNullConstraint("b"),
+ sqlite_table_name="test_constraint2",
+ sqlite_table_schema=table2_sqlite,
+ )
+
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "test_not_null_constraint",
+ delta,
+ )
+ )
+
+ if isinstance(self.store.database_engine, PostgresEngine):
+ # Postgres uses a background update
+ self.updates.register_background_validate_constraint_and_delete_rows(
+ "test_bg_update",
+ table="test_constraint",
+ constraint_name="test_constraint_name",
+ constraint=NotNullConstraint("b"),
+ unique_columns=["a"],
+ )
+
+ # Tell the DataStore that it hasn't finished all updates yet
+ self.store.db_pool.updates._all_done = False
+
+ # Now let's actually drive the updates to completion
+ self.wait_for_background_updates()
+
+ # Check the correct values are in the new table.
+ rows = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="test_constraint",
+ keyvalues={},
+ retcols=("a", "b"),
+ )
+ )
+
+ self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
+
+ # And check that invalid rows get correctly rejected.
+ self.get_failure(
+ self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": None}),
+ exc=self.store.database_engine.module.IntegrityError,
+ )
+
+ # Check the index is still there for SQLite.
+ if isinstance(self.store.database_engine, Sqlite3Engine):
+ # Ensure the index exists in the schema.
+ self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table="sqlite_master",
+ keyvalues={"tbl_name": "test_constraint"},
+ retcol="name",
+ )
+ )
+
+ def test_foreign_constraint(self) -> None:
+ """Tests adding a not foreign key constraint."""
+
+ # Create the initial tables, where we have some invalid data.
+ base_sql = """
+ CREATE TABLE base_table(
+ b INT PRIMARY KEY
+ );
+ """
+
+ table_sql = """
+ CREATE TABLE test_constraint(
+ a INT PRIMARY KEY,
+ b INT NOT NULL
+ );
+ """
+ self.get_success(
+ self.store.db_pool.execute(
+ "test_foreign_key_constraint", lambda _: None, base_sql
+ )
+ )
+ self.get_success(
+ self.store.db_pool.execute(
+ "test_foreign_key_constraint", lambda _: None, table_sql
+ )
+ )
+
+ self.get_success(self.store.db_pool.simple_insert("base_table", {"b": 1}))
+ self.get_success(
+ self.store.db_pool.simple_insert("test_constraint", {"a": 1, "b": 1})
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": 2})
+ )
+ self.get_success(self.store.db_pool.simple_insert("base_table", {"b": 3}))
+ self.get_success(
+ self.store.db_pool.simple_insert("test_constraint", {"a": 3, "b": 3})
+ )
+
+ table2_sqlite = """
+ CREATE TABLE test_constraint2(
+ a INT PRIMARY KEY,
+ b INT NOT NULL,
+ CONSTRAINT test_constraint_name FOREIGN KEY (b) REFERENCES base_table (b)
+ );
+ """
+
+ def delta(txn: LoggingTransaction) -> None:
+ run_validate_constraint_and_delete_rows_schema_delta(
+ txn,
+ ordering=1000,
+ update_name="test_bg_update",
+ table="test_constraint",
+ constraint_name="test_constraint_name",
+ constraint=ForeignKeyConstraint("base_table", [("b", "b")]),
+ sqlite_table_name="test_constraint2",
+ sqlite_table_schema=table2_sqlite,
+ )
+
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "test_foreign_key_constraint",
+ delta,
+ )
+ )
+
+ if isinstance(self.store.database_engine, PostgresEngine):
+ # Postgres uses a background update
+ self.updates.register_background_validate_constraint_and_delete_rows(
+ "test_bg_update",
+ table="test_constraint",
+ constraint_name="test_constraint_name",
+ constraint=ForeignKeyConstraint("base_table", [("b", "b")]),
+ unique_columns=["a"],
+ )
+
+ # Tell the DataStore that it hasn't finished all updates yet
+ self.store.db_pool.updates._all_done = False
+
+ # Now let's actually drive the updates to completion
+ self.wait_for_background_updates()
+
+ # Check the correct values are in the new table.
+ rows = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="test_constraint",
+ keyvalues={},
+ retcols=("a", "b"),
+ )
+ )
+ self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
+
+ # And check that invalid rows get correctly rejected.
+ self.get_failure(
+ self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": 2}),
+ exc=self.store.database_engine.module.IntegrityError,
+ )
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 0f3b0744f1..9c151a5e62 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -20,6 +20,7 @@ from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
+from synapse.api.constants import EventTypes
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
@@ -98,8 +99,32 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
room2 = "#room2"
room3 = "#room3"
- def insert_event(txn: Cursor, i: int, room_id: str) -> None:
+ def insert_event(txn: LoggingTransaction, i: int, room_id: str) -> None:
event_id = "$event_%i:local" % i
+
+ # We need to insert into events table to get around the foreign key constraint.
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "instance_name": "master",
+ "stream_ordering": self.store._stream_id_gen.get_next_txn(txn),
+ "topological_ordering": 1,
+ "depth": 1,
+ "event_id": event_id,
+ "room_id": room_id,
+ "type": EventTypes.Message,
+ "processed": True,
+ "outlier": False,
+ "origin_server_ts": 0,
+ "received_ts": 0,
+ "sender": "@user:local",
+ "contains_url": False,
+ "state_key": None,
+ "rejection_reason": None,
+ },
+ )
+
txn.execute(
(
"INSERT INTO event_forward_extremities (room_id, event_id) "
@@ -113,10 +138,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.store.db_pool.runInteraction("insert", insert_event, i, room1)
)
self.get_success(
- self.store.db_pool.runInteraction("insert", insert_event, i, room2)
+ self.store.db_pool.runInteraction(
+ "insert", insert_event, i + 100, room2
+ )
)
self.get_success(
- self.store.db_pool.runInteraction("insert", insert_event, i, room3)
+ self.store.db_pool.runInteraction(
+ "insert", insert_event, i + 200, room3
+ )
)
# Test simple case
|