summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12199.misc1
-rw-r--r--synapse/storage/database.py61
-rw-r--r--tests/storage/test_database.py58
3 files changed, 96 insertions, 24 deletions
diff --git a/changelog.d/12199.misc b/changelog.d/12199.misc
new file mode 100644
index 0000000000..16dec1d26d
--- /dev/null
+++ b/changelog.d/12199.misc
@@ -0,0 +1 @@
+Handle cancellation in `DatabasePool.runInteraction()`.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 99802228c9..9749f0c06e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -41,6 +41,7 @@ from prometheus_client import Histogram
 from typing_extensions import Literal
 
 from twisted.enterprise import adbapi
+from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
@@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.util.async_helpers import delay_cancellation
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -732,34 +734,45 @@ class DatabasePool:
         Returns:
             The result of func
         """
-        after_callbacks: List[_CallbackListEntry] = []
-        exception_callbacks: List[_CallbackListEntry] = []
 
-        if not current_context():
-            logger.warning("Starting db txn '%s' from sentinel context", desc)
+        async def _runInteraction() -> R:
+            after_callbacks: List[_CallbackListEntry] = []
+            exception_callbacks: List[_CallbackListEntry] = []
 
-        try:
-            with opentracing.start_active_span(f"db.{desc}"):
-                result = await self.runWithConnection(
-                    self.new_transaction,
-                    desc,
-                    after_callbacks,
-                    exception_callbacks,
-                    func,
-                    *args,
-                    db_autocommit=db_autocommit,
-                    isolation_level=isolation_level,
-                    **kwargs,
-                )
+            if not current_context():
+                logger.warning("Starting db txn '%s' from sentinel context", desc)
 
-            for after_callback, after_args, after_kwargs in after_callbacks:
-                after_callback(*after_args, **after_kwargs)
-        except Exception:
-            for after_callback, after_args, after_kwargs in exception_callbacks:
-                after_callback(*after_args, **after_kwargs)
-            raise
+            try:
+                with opentracing.start_active_span(f"db.{desc}"):
+                    result = await self.runWithConnection(
+                        self.new_transaction,
+                        desc,
+                        after_callbacks,
+                        exception_callbacks,
+                        func,
+                        *args,
+                        db_autocommit=db_autocommit,
+                        isolation_level=isolation_level,
+                        **kwargs,
+                    )
 
-        return cast(R, result)
+                for after_callback, after_args, after_kwargs in after_callbacks:
+                    after_callback(*after_args, **after_kwargs)
+
+                return cast(R, result)
+            except Exception:
+                for after_callback, after_args, after_kwargs in exception_callbacks:
+                    after_callback(*after_args, **after_kwargs)
+                raise
+
+        # To handle cancellation, we ensure that `after_callback`s and
+        # `exception_callback`s are always run, since the transaction will complete
+        # on another thread regardless of cancellation.
+        #
+        # We also wait until everything above is done before releasing the
+        # `CancelledError`, so that logging contexts won't get used after they have been
+        # finished.
+        return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
 
     async def runWithConnection(
         self,
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index ae13bed086..a40fc20ef9 100644
--- a/tests/storage/test_database.py
+++ b/tests/storage/test_database.py
@@ -15,6 +15,8 @@
 from typing import Callable, Tuple
 from unittest.mock import Mock, call
 
+from twisted.internet import defer
+from twisted.internet.defer import CancelledError, Deferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.server import HomeServer
@@ -124,3 +126,59 @@ class CallbacksTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(after_callback.call_count, 2)  # no additional calls
         exception_callback.assert_not_called()
+
+
+class CancellationTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.db_pool: DatabasePool = self.store.db_pool
+
+    def test_after_callback(self) -> None:
+        """Test that the after callback is called when a transaction succeeds."""
+        d: "Deferred[None]"
+        after_callback = Mock()
+        exception_callback = Mock()
+
+        def _test_txn(txn: LoggingTransaction) -> None:
+            txn.call_after(after_callback, 123, 456, extra=789)
+            txn.call_on_exception(exception_callback, 987, 654, extra=321)
+            d.cancel()
+
+        d = defer.ensureDeferred(
+            self.db_pool.runInteraction("test_transaction", _test_txn)
+        )
+        self.get_failure(d, CancelledError)
+
+        after_callback.assert_called_once_with(123, 456, extra=789)
+        exception_callback.assert_not_called()
+
+    def test_exception_callback(self) -> None:
+        """Test that the exception callback is called when a transaction fails."""
+        d: "Deferred[None]"
+        after_callback = Mock()
+        exception_callback = Mock()
+
+        def _test_txn(txn: LoggingTransaction) -> None:
+            txn.call_after(after_callback, 123, 456, extra=789)
+            txn.call_on_exception(exception_callback, 987, 654, extra=321)
+            d.cancel()
+            # Simulate a retryable failure on every attempt.
+            raise self.db_pool.engine.module.OperationalError()
+
+        d = defer.ensureDeferred(
+            self.db_pool.runInteraction("test_transaction", _test_txn)
+        )
+        self.get_failure(d, CancelledError)
+
+        after_callback.assert_not_called()
+        exception_callback.assert_has_calls(
+            [
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+            ]
+        )
+        self.assertEqual(exception_callback.call_count, 6)  # no additional calls