summary refs log tree commit diff
path: root/tests/storage/test_database.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_database.py')
-rw-r--r--tests/storage/test_database.py96
1 files changed, 96 insertions, 0 deletions
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index 543cce6b3e..8cd7c89ca2 100644
--- a/tests/storage/test_database.py
+++ b/tests/storage/test_database.py
@@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.server import HomeServer
 from synapse.storage.database import (
     DatabasePool,
+    LoggingDatabaseConnection,
     LoggingTransaction,
     make_tuple_comparison_clause,
 )
@@ -37,6 +38,101 @@ class TupleComparisonClauseTestCase(unittest.TestCase):
         self.assertEqual(args, [1, 2])
 
 
+class ExecuteScriptTestCase(unittest.HomeserverTestCase):
+    """Tests for `BaseDatabaseEngine.executescript` implementations."""
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.db_pool: DatabasePool = self.store.db_pool
+        self.get_success(
+            self.db_pool.runInteraction(
+                "create",
+                lambda txn: txn.execute("CREATE TABLE foo (name TEXT PRIMARY KEY)"),
+            )
+        )
+
+    def test_transaction(self) -> None:
+        """Test that all statements are run in a single transaction."""
+
+        def run(conn: LoggingDatabaseConnection) -> None:
+            cur = conn.cursor(txn_name="test_transaction")
+            self.db_pool.engine.executescript(
+                cur,
+                ";".join(
+                    [
+                        "INSERT INTO foo (name) VALUES ('transaction test')",
+                        # This next statement will fail. When `executescript` is not
+                        # transactional, the previous row will be observed later.
+                        "INSERT INTO foo (name) VALUES ('transaction test')",
+                    ]
+                ),
+            )
+
+        self.get_failure(
+            self.db_pool.runWithConnection(run),
+            self.db_pool.engine.module.IntegrityError,
+        )
+
+        self.assertIsNone(
+            self.get_success(
+                self.db_pool.simple_select_one_onecol(
+                    "foo",
+                    keyvalues={"name": "transaction test"},
+                    retcol="name",
+                    allow_none=True,
+                )
+            ),
+            "executescript is not running statements inside a transaction",
+        )
+
+    def test_commit(self) -> None:
+        """Test that the script transaction remains open and can be committed."""
+
+        def run(conn: LoggingDatabaseConnection) -> None:
+            cur = conn.cursor(txn_name="test_commit")
+            self.db_pool.engine.executescript(
+                cur, "INSERT INTO foo (name) VALUES ('commit test')"
+            )
+            cur.execute("COMMIT")
+
+        self.get_success(self.db_pool.runWithConnection(run))
+
+        self.assertIsNotNone(
+            self.get_success(
+                self.db_pool.simple_select_one_onecol(
+                    "foo",
+                    keyvalues={"name": "commit test"},
+                    retcol="name",
+                    allow_none=True,
+                )
+            ),
+        )
+
+    def test_rollback(self) -> None:
+        """Test that the script transaction remains open and can be rolled back."""
+
+        def run(conn: LoggingDatabaseConnection) -> None:
+            cur = conn.cursor(txn_name="test_rollback")
+            self.db_pool.engine.executescript(
+                cur, "INSERT INTO foo (name) VALUES ('rollback test')"
+            )
+            cur.execute("ROLLBACK")
+
+        self.get_success(self.db_pool.runWithConnection(run))
+
+        self.assertIsNone(
+            self.get_success(
+                self.db_pool.simple_select_one_onecol(
+                    "foo",
+                    keyvalues={"name": "rollback test"},
+                    retcol="name",
+                    allow_none=True,
+                )
+            ),
+            "executescript is not leaving the script transaction open",
+        )
+
+
 class CallbacksTestCase(unittest.HomeserverTestCase):
     """Tests for transaction callbacks."""