diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 99802228c9..72fef1533f 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -41,6 +41,7 @@ from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
+from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
+from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -286,13 +288,17 @@ class LoggingTransaction:
"""
if isinstance(self.database_engine, PostgresEngine):
- from psycopg2.extras import execute_batch # type: ignore
+ from psycopg2.extras import execute_batch
- self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
+ self._do_execute(
+ lambda the_sql: execute_batch(self.txn, the_sql, args), sql
+ )
else:
self.executemany(sql, args)
- def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple]:
+ def execute_values(
+ self, sql: str, values: Iterable[Iterable[Any]], fetch: bool = True
+ ) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when
using postgres.
@@ -300,10 +306,11 @@ class LoggingTransaction:
rows (e.g. INSERTs).
"""
assert isinstance(self.database_engine, PostgresEngine)
- from psycopg2.extras import execute_values # type: ignore
+ from psycopg2.extras import execute_values
return self._do_execute(
- lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args
+ lambda the_sql: execute_values(self.txn, the_sql, values, fetch=fetch),
+ sql,
)
def execute(self, sql: str, *args: Any) -> None:
@@ -732,34 +739,45 @@ class DatabasePool:
Returns:
The result of func
"""
- after_callbacks: List[_CallbackListEntry] = []
- exception_callbacks: List[_CallbackListEntry] = []
- if not current_context():
- logger.warning("Starting db txn '%s' from sentinel context", desc)
+ async def _runInteraction() -> R:
+ after_callbacks: List[_CallbackListEntry] = []
+ exception_callbacks: List[_CallbackListEntry] = []
- try:
- with opentracing.start_active_span(f"db.{desc}"):
- result = await self.runWithConnection(
- self.new_transaction,
- desc,
- after_callbacks,
- exception_callbacks,
- func,
- *args,
- db_autocommit=db_autocommit,
- isolation_level=isolation_level,
- **kwargs,
- )
+ if not current_context():
+ logger.warning("Starting db txn '%s' from sentinel context", desc)
- for after_callback, after_args, after_kwargs in after_callbacks:
- after_callback(*after_args, **after_kwargs)
- except Exception:
- for after_callback, after_args, after_kwargs in exception_callbacks:
- after_callback(*after_args, **after_kwargs)
- raise
+ try:
+ with opentracing.start_active_span(f"db.{desc}"):
+ result = await self.runWithConnection(
+ self.new_transaction,
+ desc,
+ after_callbacks,
+ exception_callbacks,
+ func,
+ *args,
+ db_autocommit=db_autocommit,
+ isolation_level=isolation_level,
+ **kwargs,
+ )
- return cast(R, result)
+ for after_callback, after_args, after_kwargs in after_callbacks:
+ after_callback(*after_args, **after_kwargs)
+
+ return cast(R, result)
+ except Exception:
+ for after_callback, after_args, after_kwargs in exception_callbacks:
+ after_callback(*after_args, **after_kwargs)
+ raise
+
+ # To handle cancellation, we ensure that `after_callback`s and
+ # `exception_callback`s are always run, since the transaction will complete
+ # on another thread regardless of cancellation.
+ #
+ # We also wait until everything above is done before releasing the
+ # `CancelledError`, so that logging contexts won't get used after they have been
+ # finished.
+ return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
async def runWithConnection(
self,
|