diff options
Diffstat (limited to 'tests/storage/test_database.py')
-rw-r--r-- | tests/storage/test_database.py | 96 |
1 files changed, 96 insertions, 0 deletions
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index 543cce6b3e..8cd7c89ca2 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer from synapse.storage.database import ( DatabasePool, + LoggingDatabaseConnection, LoggingTransaction, make_tuple_comparison_clause, ) @@ -37,6 +38,101 @@ class TupleComparisonClauseTestCase(unittest.TestCase): self.assertEqual(args, [1, 2]) +class ExecuteScriptTestCase(unittest.HomeserverTestCase): + """Tests for `BaseDatabaseEngine.executescript` implementations.""" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + self.get_success( + self.db_pool.runInteraction( + "create", + lambda txn: txn.execute("CREATE TABLE foo (name TEXT PRIMARY KEY)"), + ) + ) + + def test_transaction(self) -> None: + """Test that all statements are run in a single transaction.""" + + def run(conn: LoggingDatabaseConnection) -> None: + cur = conn.cursor(txn_name="test_transaction") + self.db_pool.engine.executescript( + cur, + ";".join( + [ + "INSERT INTO foo (name) VALUES ('transaction test')", + # This next statement will fail. When `executescript` is not + # transactional, the previous row will be observed later. + "INSERT INTO foo (name) VALUES ('transaction test')", + ] + ), + ) + + self.get_failure( + self.db_pool.runWithConnection(run), + self.db_pool.engine.module.IntegrityError, + ) + + self.assertIsNone( + self.get_success( + self.db_pool.simple_select_one_onecol( + "foo", + keyvalues={"name": "transaction test"}, + retcol="name", + allow_none=True, + ) + ), + "executescript is not running statements inside a transaction", + ) + + def test_commit(self) -> None: + """Test that the script transaction remains open and can be committed.""" + + def run(conn: LoggingDatabaseConnection) -> None: + cur = conn.cursor(txn_name="test_commit") + self.db_pool.engine.executescript( + cur, "INSERT INTO foo (name) VALUES ('commit test')" + ) + cur.execute("COMMIT") + + self.get_success(self.db_pool.runWithConnection(run)) + + self.assertIsNotNone( + self.get_success( + self.db_pool.simple_select_one_onecol( + "foo", + keyvalues={"name": "commit test"}, + retcol="name", + allow_none=True, + ) + ), + ) + + def test_rollback(self) -> None: + """Test that the script transaction remains open and can be rolled back.""" + + def run(conn: LoggingDatabaseConnection) -> None: + cur = conn.cursor(txn_name="test_rollback") + self.db_pool.engine.executescript( + cur, "INSERT INTO foo (name) VALUES ('rollback test')" + ) + cur.execute("ROLLBACK") + + self.get_success(self.db_pool.runWithConnection(run)) + + self.assertIsNone( + self.get_success( + self.db_pool.simple_select_one_onecol( + "foo", + keyvalues={"name": "rollback test"}, + retcol="name", + allow_none=True, + ) + ), + "executescript is not leaving the script transaction open", + ) + + class CallbacksTestCase(unittest.HomeserverTestCase): """Tests for transaction callbacks.""" |