From a4c1fdb44a16471964ed6a347be6a191102f5c07 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 9 Mar 2022 18:45:21 +0000 Subject: Remove dead code in `tests/storage/test_database.py` (#12197) Signed-off-by: Sean Quah --- tests/storage/test_database.py | 16 ---------------- 1 file changed, 16 deletions(-) (limited to 'tests/storage/test_database.py') diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index 6fbac0ab14..8597867563 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -13,26 +13,10 @@ # limitations under the License. from synapse.storage.database import make_tuple_comparison_clause -from synapse.storage.engines import BaseDatabaseEngine from tests import unittest -def _stub_db_engine(**kwargs) -> BaseDatabaseEngine: - # returns a DatabaseEngine, circumventing the abc mechanism - # any kwargs are set as attributes on the class before instantiating it - t = type( - "TestBaseDatabaseEngine", - (BaseDatabaseEngine,), - dict(BaseDatabaseEngine.__dict__), - ) - # defeat the abc mechanism - t.__abstractmethods__ = set() - for k, v in kwargs.items(): - setattr(t, k, v) - return t(None, None) - - class TupleComparisonClauseTestCase(unittest.TestCase): def test_native_tuple_comparison(self): clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)]) -- cgit 1.5.1 From dea577998f221297d3ff30bdf904f7147f3c3d8a Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 15 Mar 2022 15:40:34 +0000 Subject: Add tests for database transaction callbacks (#12198) Signed-off-by: Sean Quah --- changelog.d/12198.misc | 1 + tests/storage/test_database.py | 104 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12198.misc (limited to 'tests/storage/test_database.py') diff --git a/changelog.d/12198.misc b/changelog.d/12198.misc new file mode 100644 index 0000000000..6b184a9053 --- /dev/null +++ b/changelog.d/12198.misc @@ -0,0 +1 @@ +Add tests for database transaction callbacks. diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index 8597867563..ae13bed086 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -12,7 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.database import make_tuple_comparison_clause +from typing import Callable, Tuple +from unittest.mock import Mock, call + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.storage.database import ( + DatabasePool, + LoggingTransaction, + make_tuple_comparison_clause, +) +from synapse.util import Clock from tests import unittest @@ -22,3 +33,94 @@ class TupleComparisonClauseTestCase(unittest.TestCase): clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)]) self.assertEqual(clause, "(a,b) > (?,?)") self.assertEqual(args, [1, 2]) + + +class CallbacksTestCase(unittest.HomeserverTestCase): + """Tests for transaction callbacks.""" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + def _run_interaction( + self, func: Callable[[LoggingTransaction], object] + ) -> Tuple[Mock, Mock]: + """Run the given function in a database transaction, with callbacks registered. + + Args: + func: The function to be run in a transaction. The transaction will be + retried if `func` raises an `OperationalError`. + + Returns: + Two mocks, which were registered as an `after_callback` and an + `exception_callback` respectively, on every transaction attempt. + """ + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + func(txn) + + try: + self.get_success_or_raise( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + except Exception: + pass + + return after_callback, exception_callback + + def test_after_callback(self) -> None: + """Test that the after callback is called when a transaction succeeds.""" + after_callback, exception_callback = self._run_interaction(lambda txn: None) + + after_callback.assert_called_once_with(123, 456, extra=789) + exception_callback.assert_not_called() + + def test_exception_callback(self) -> None: + """Test that the exception callback is called when a transaction fails.""" + _test_txn = Mock(side_effect=ZeroDivisionError) + after_callback, exception_callback = self._run_interaction(_test_txn) + + after_callback.assert_not_called() + exception_callback.assert_called_once_with(987, 654, extra=321) + + def test_failed_retry(self) -> None: + """Test that the exception callback is called for every failed attempt.""" + # Always raise an `OperationalError`. + _test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError) + after_callback, exception_callback = self._run_interaction(_test_txn) + + after_callback.assert_not_called() + exception_callback.assert_has_calls( + [ + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + ] + ) + self.assertEqual(exception_callback.call_count, 6) # no additional calls + + def test_successful_retry(self) -> None: + """Test callbacks for a failed transaction followed by a successful attempt.""" + # Raise an `OperationalError` on the first attempt only. + _test_txn = Mock( + side_effect=[self.db_pool.engine.module.OperationalError, None] + ) + after_callback, exception_callback = self._run_interaction(_test_txn) + + # Calling both `after_callback`s when the first attempt failed is rather + # surprising (#12184). Let's document the behaviour in a test. + after_callback.assert_has_calls( + [ + call(123, 456, extra=789), + call(123, 456, extra=789), + ] + ) + self.assertEqual(after_callback.call_count, 2) # no additional calls + exception_callback.assert_not_called() -- cgit 1.5.1 From 61210567405b1ac7efaa23d5513cc0b443da0a3a Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 16 Mar 2022 15:07:41 +0000 Subject: Handle cancellation in `DatabasePool.runInteraction()` (#12199) To handle cancellation, we ensure that `after_callback`s and `exception_callback`s are always run, since the transaction will complete on another thread regardless of cancellation. We also wait until everything is done before releasing the `CancelledError`, so that logging contexts won't get used after they have been finished. Signed-off-by: Sean Quah --- changelog.d/12199.misc | 1 + synapse/storage/database.py | 61 +++++++++++++++++++++++++----------------- tests/storage/test_database.py | 58 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 24 deletions(-) create mode 100644 changelog.d/12199.misc (limited to 'tests/storage/test_database.py') diff --git a/changelog.d/12199.misc b/changelog.d/12199.misc new file mode 100644 index 0000000000..16dec1d26d --- /dev/null +++ b/changelog.d/12199.misc @@ -0,0 +1 @@ +Handle cancellation in `DatabasePool.runInteraction()`. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 99802228c9..9749f0c06e 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -41,6 +41,7 @@ from prometheus_client import Histogram from typing_extensions import Literal from twisted.enterprise import adbapi +from twisted.internet import defer from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig @@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +from synapse.util.async_helpers import delay_cancellation from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -732,34 +734,45 @@ class DatabasePool: Returns: The result of func """ - after_callbacks: List[_CallbackListEntry] = [] - exception_callbacks: List[_CallbackListEntry] = [] - if not current_context(): - logger.warning("Starting db txn '%s' from sentinel context", desc) + async def _runInteraction() -> R: + after_callbacks: List[_CallbackListEntry] = [] + exception_callbacks: List[_CallbackListEntry] = [] - try: - with opentracing.start_active_span(f"db.{desc}"): - result = await self.runWithConnection( - self.new_transaction, - desc, - after_callbacks, - exception_callbacks, - func, - *args, - db_autocommit=db_autocommit, - isolation_level=isolation_level, - **kwargs, - ) + if not current_context(): + logger.warning("Starting db txn '%s' from sentinel context", desc) - for after_callback, after_args, after_kwargs in after_callbacks: - after_callback(*after_args, **after_kwargs) - except Exception: - for after_callback, after_args, after_kwargs in exception_callbacks: - after_callback(*after_args, **after_kwargs) - raise + try: + with opentracing.start_active_span(f"db.{desc}"): + result = await self.runWithConnection( + self.new_transaction, + desc, + after_callbacks, + exception_callbacks, + func, + *args, + db_autocommit=db_autocommit, + isolation_level=isolation_level, + **kwargs, + ) - return cast(R, result) + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) + + return cast(R, result) + except Exception: + for after_callback, after_args, after_kwargs in exception_callbacks: + after_callback(*after_args, **after_kwargs) + raise + + # To handle cancellation, we ensure that `after_callback`s and + # `exception_callback`s are always run, since the transaction will complete + # on another thread regardless of cancellation. + # + # We also wait until everything above is done before releasing the + # `CancelledError`, so that logging contexts won't get used after they have been + # finished. + return await delay_cancellation(defer.ensureDeferred(_runInteraction())) async def runWithConnection( self, diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index ae13bed086..a40fc20ef9 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -15,6 +15,8 @@ from typing import Callable, Tuple from unittest.mock import Mock, call +from twisted.internet import defer +from twisted.internet.defer import CancelledError, Deferred from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer @@ -124,3 +126,59 @@ class CallbacksTestCase(unittest.HomeserverTestCase): ) self.assertEqual(after_callback.call_count, 2) # no additional calls exception_callback.assert_not_called() + + +class CancellationTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + def test_after_callback(self) -> None: + """Test that the after callback is called when a transaction succeeds.""" + d: "Deferred[None]" + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + d.cancel() + + d = defer.ensureDeferred( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + self.get_failure(d, CancelledError) + + after_callback.assert_called_once_with(123, 456, extra=789) + exception_callback.assert_not_called() + + def test_exception_callback(self) -> None: + """Test that the exception callback is called when a transaction fails.""" + d: "Deferred[None]" + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + d.cancel() + # Simulate a retryable failure on every attempt. + raise self.db_pool.engine.module.OperationalError() + + d = defer.ensureDeferred( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + self.get_failure(d, CancelledError) + + after_callback.assert_not_called() + exception_callback.assert_has_calls( + [ + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + ] + ) + self.assertEqual(exception_callback.call_count, 6) # no additional calls -- cgit 1.5.1