diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b971f0cb18..2fbebd4907 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -48,16 +48,16 @@ class LoggingTransaction(object):
passed to the constructor. Adds logging and metrics to the .execute()
method."""
__slots__ = [
- "txn", "name", "database_engine", "after_callbacks", "final_callbacks",
+ "txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
]
def __init__(self, txn, name, database_engine, after_callbacks,
- final_callbacks):
+ exception_callbacks):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks)
- object.__setattr__(self, "final_callbacks", final_callbacks)
+ object.__setattr__(self, "exception_callbacks", exception_callbacks)
def call_after(self, callback, *args, **kwargs):
"""Call the given callback on the main twisted thread after the
@@ -66,8 +66,8 @@ class LoggingTransaction(object):
"""
self.after_callbacks.append((callback, args, kwargs))
- def call_finally(self, callback, *args, **kwargs):
- self.final_callbacks.append((callback, args, kwargs))
+ def call_on_exception(self, callback, *args, **kwargs):
+ self.exception_callbacks.append((callback, args, kwargs))
def __getattr__(self, name):
return getattr(self.txn, name)
@@ -215,7 +215,7 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000)
- def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
+ def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
logging_context, func, *args, **kwargs):
start = time.time() * 1000
txn_id = self._TXN_ID
@@ -236,7 +236,7 @@ class SQLBaseStore(object):
txn = conn.cursor()
txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks,
- final_callbacks,
+ exception_callbacks,
)
r = func(txn, *args, **kwargs)
conn.commit()
@@ -291,52 +291,66 @@ class SQLBaseStore(object):
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
- current_context = LoggingContext.current_context()
+ """Starts a transaction on the database and runs a given function
- start_time = time.time() * 1000
+ Arguments:
+ desc (str): description of the transaction, for logging and metrics
+ func (func): callback function, which will be called with a
+ database transaction (twisted.enterprise.adbapi.Transaction) as
+ its first argument, followed by `args` and `kwargs`.
- after_callbacks = []
- final_callbacks = []
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
- def inner_func(conn, *args, **kwargs):
- with LoggingContext("runInteraction") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+ Returns:
+ Deferred: The result of func
+ """
+ current_context = LoggingContext.current_context()
- if self.database_engine.is_connection_closed(conn):
- logger.debug("Reconnecting closed database connection")
- conn.reconnect()
+ after_callbacks = []
+ exception_callbacks = []
- current_context.copy_to(context)
- return self._new_transaction(
- conn, desc, after_callbacks, final_callbacks, current_context,
- func, *args, **kwargs
- )
+ def inner_func(conn, *args, **kwargs):
+ return self._new_transaction(
+ conn, desc, after_callbacks, exception_callbacks, current_context,
+ func, *args, **kwargs
+ )
try:
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(
- inner_func, *args, **kwargs
- )
+ result = yield self.runWithConnection(inner_func, *args, **kwargs)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
- finally:
- for after_callback, after_args, after_kwargs in final_callbacks:
+ except: # noqa: E722, as we reraise the exception this is fine.
+ for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
+ raise
defer.returnValue(result)
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
+ """Wraps the .runWithConnection() method on the underlying db_pool.
+
+ Arguments:
+ func (func): callback function, which will be called with a
+ database connection (twisted.enterprise.adbapi.Connection) as
+ its first argument, followed by `args` and `kwargs`.
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+ sched_duration_ms = time.time() * 1000 - start_time
+ sql_scheduling_timer.inc_by(sched_duration_ms)
+ current_context.add_database_scheduled(sched_duration_ms)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
@@ -987,7 +1001,8 @@ class SQLBaseStore(object):
# __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__()
- txn.call_finally(ctx.__exit__, None, None, None)
+ txn.call_on_exception(ctx.__exit__, None, None, None)
+ txn.call_after(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self._simple_insert_txn(
|