diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index 543cce6b3e..8cd7c89ca2 100644
--- a/tests/storage/test_database.py
+++ b/tests/storage/test_database.py
@@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.database import (
DatabasePool,
+ LoggingDatabaseConnection,
LoggingTransaction,
make_tuple_comparison_clause,
)
@@ -37,6 +38,101 @@ class TupleComparisonClauseTestCase(unittest.TestCase):
self.assertEqual(args, [1, 2])
+class ExecuteScriptTestCase(unittest.HomeserverTestCase):
+ """Tests for `BaseDatabaseEngine.executescript` implementations."""
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.db_pool: DatabasePool = self.store.db_pool
+ self.get_success(
+ self.db_pool.runInteraction(
+ "create",
+ lambda txn: txn.execute("CREATE TABLE foo (name TEXT PRIMARY KEY)"),
+ )
+ )
+
+ def test_transaction(self) -> None:
+ """Test that all statements are run in a single transaction."""
+
+ def run(conn: LoggingDatabaseConnection) -> None:
+ cur = conn.cursor(txn_name="test_transaction")
+ self.db_pool.engine.executescript(
+ cur,
+ ";".join(
+ [
+ "INSERT INTO foo (name) VALUES ('transaction test')",
+ # This next statement will fail. When `executescript` is not
+ # transactional, the previous row will be observed later.
+ "INSERT INTO foo (name) VALUES ('transaction test')",
+ ]
+ ),
+ )
+
+ self.get_failure(
+ self.db_pool.runWithConnection(run),
+ self.db_pool.engine.module.IntegrityError,
+ )
+
+ self.assertIsNone(
+ self.get_success(
+ self.db_pool.simple_select_one_onecol(
+ "foo",
+ keyvalues={"name": "transaction test"},
+ retcol="name",
+ allow_none=True,
+ )
+ ),
+ "executescript is not running statements inside a transaction",
+ )
+
+ def test_commit(self) -> None:
+ """Test that the script transaction remains open and can be committed."""
+
+ def run(conn: LoggingDatabaseConnection) -> None:
+ cur = conn.cursor(txn_name="test_commit")
+ self.db_pool.engine.executescript(
+ cur, "INSERT INTO foo (name) VALUES ('commit test')"
+ )
+ cur.execute("COMMIT")
+
+ self.get_success(self.db_pool.runWithConnection(run))
+
+ self.assertIsNotNone(
+ self.get_success(
+ self.db_pool.simple_select_one_onecol(
+ "foo",
+ keyvalues={"name": "commit test"},
+ retcol="name",
+ allow_none=True,
+ )
+ ),
+ )
+
+ def test_rollback(self) -> None:
+ """Test that the script transaction remains open and can be rolled back."""
+
+ def run(conn: LoggingDatabaseConnection) -> None:
+ cur = conn.cursor(txn_name="test_rollback")
+ self.db_pool.engine.executescript(
+ cur, "INSERT INTO foo (name) VALUES ('rollback test')"
+ )
+ cur.execute("ROLLBACK")
+
+ self.get_success(self.db_pool.runWithConnection(run))
+
+ self.assertIsNone(
+ self.get_success(
+ self.db_pool.simple_select_one_onecol(
+ "foo",
+ keyvalues={"name": "rollback test"},
+ retcol="name",
+ allow_none=True,
+ )
+ ),
+ "executescript is not leaving the script transaction open",
+ )
+
+
class CallbacksTestCase(unittest.HomeserverTestCase):
"""Tests for transaction callbacks."""
|