diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 99802228c9..9749f0c06e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -41,6 +41,7 @@ from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
+from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
+from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -732,34 +734,45 @@ class DatabasePool:
Returns:
The result of func
"""
- after_callbacks: List[_CallbackListEntry] = []
- exception_callbacks: List[_CallbackListEntry] = []
- if not current_context():
- logger.warning("Starting db txn '%s' from sentinel context", desc)
+ async def _runInteraction() -> R:
+ after_callbacks: List[_CallbackListEntry] = []
+ exception_callbacks: List[_CallbackListEntry] = []
- try:
- with opentracing.start_active_span(f"db.{desc}"):
- result = await self.runWithConnection(
- self.new_transaction,
- desc,
- after_callbacks,
- exception_callbacks,
- func,
- *args,
- db_autocommit=db_autocommit,
- isolation_level=isolation_level,
- **kwargs,
- )
+ if not current_context():
+ logger.warning("Starting db txn '%s' from sentinel context", desc)
- for after_callback, after_args, after_kwargs in after_callbacks:
- after_callback(*after_args, **after_kwargs)
- except Exception:
- for after_callback, after_args, after_kwargs in exception_callbacks:
- after_callback(*after_args, **after_kwargs)
- raise
+ try:
+ with opentracing.start_active_span(f"db.{desc}"):
+ result = await self.runWithConnection(
+ self.new_transaction,
+ desc,
+ after_callbacks,
+ exception_callbacks,
+ func,
+ *args,
+ db_autocommit=db_autocommit,
+ isolation_level=isolation_level,
+ **kwargs,
+ )
- return cast(R, result)
+ for after_callback, after_args, after_kwargs in after_callbacks:
+ after_callback(*after_args, **after_kwargs)
+
+ return cast(R, result)
+ except Exception:
+ for after_callback, after_args, after_kwargs in exception_callbacks:
+ after_callback(*after_args, **after_kwargs)
+ raise
+
+ # To handle cancellation, we ensure that `after_callback`s and
+ # `exception_callback`s are always run, since the transaction will complete
+ # on another thread regardless of cancellation.
+ #
+ # We also wait until everything above is done before releasing the
+ # `CancelledError`, so that logging contexts won't get used after they have been
+ # finished.
+ return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
async def runWithConnection(
self,
|