summary refs log tree commit diff
path: root/synapse/storage/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r--synapse/storage/database.py76
1 files changed, 47 insertions, 29 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py

index 99802228c9..72fef1533f 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -41,6 +41,7 @@ from prometheus_client import Histogram from typing_extensions import Literal from twisted.enterprise import adbapi +from twisted.internet import defer from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig @@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +from synapse.util.async_helpers import delay_cancellation from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -286,13 +288,17 @@ class LoggingTransaction: """ if isinstance(self.database_engine, PostgresEngine): - from psycopg2.extras import execute_batch # type: ignore + from psycopg2.extras import execute_batch - self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) + self._do_execute( + lambda the_sql: execute_batch(self.txn, the_sql, args), sql + ) else: self.executemany(sql, args) - def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple]: + def execute_values( + self, sql: str, values: Iterable[Iterable[Any]], fetch: bool = True + ) -> List[Tuple]: """Corresponds to psycopg2.extras.execute_values. Only available when using postgres. @@ -300,10 +306,11 @@ class LoggingTransaction: rows (e.g. INSERTs). """ assert isinstance(self.database_engine, PostgresEngine) - from psycopg2.extras import execute_values # type: ignore + from psycopg2.extras import execute_values return self._do_execute( - lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args + lambda the_sql: execute_values(self.txn, the_sql, values, fetch=fetch), + sql, ) def execute(self, sql: str, *args: Any) -> None: @@ -732,34 +739,45 @@ class DatabasePool: Returns: The result of func """ - after_callbacks: List[_CallbackListEntry] = [] - exception_callbacks: List[_CallbackListEntry] = [] - if not current_context(): - logger.warning("Starting db txn '%s' from sentinel context", desc) + async def _runInteraction() -> R: + after_callbacks: List[_CallbackListEntry] = [] + exception_callbacks: List[_CallbackListEntry] = [] - try: - with opentracing.start_active_span(f"db.{desc}"): - result = await self.runWithConnection( - self.new_transaction, - desc, - after_callbacks, - exception_callbacks, - func, - *args, - db_autocommit=db_autocommit, - isolation_level=isolation_level, - **kwargs, - ) + if not current_context(): + logger.warning("Starting db txn '%s' from sentinel context", desc) - for after_callback, after_args, after_kwargs in after_callbacks: - after_callback(*after_args, **after_kwargs) - except Exception: - for after_callback, after_args, after_kwargs in exception_callbacks: - after_callback(*after_args, **after_kwargs) - raise + try: + with opentracing.start_active_span(f"db.{desc}"): + result = await self.runWithConnection( + self.new_transaction, + desc, + after_callbacks, + exception_callbacks, + func, + *args, + db_autocommit=db_autocommit, + isolation_level=isolation_level, + **kwargs, + ) - return cast(R, result) + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) + + return cast(R, result) + except Exception: + for after_callback, after_args, after_kwargs in exception_callbacks: + after_callback(*after_args, **after_kwargs) + raise + + # To handle cancellation, we ensure that `after_callback`s and + # `exception_callback`s are always run, since the transaction will complete + # on another thread regardless of cancellation. + # + # We also wait until everything above is done before releasing the + # `CancelledError`, so that logging contexts won't get used after they have been + # finished. + return await delay_cancellation(defer.ensureDeferred(_runInteraction())) async def runWithConnection( self,