summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py79
1 files changed, 47 insertions, 32 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b971f0cb18..2fbebd4907 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -48,16 +48,16 @@ class LoggingTransaction(object):
     passed to the constructor. Adds logging and metrics to the .execute()
     method."""
     __slots__ = [
-        "txn", "name", "database_engine", "after_callbacks", "final_callbacks",
+        "txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
     ]
 
     def __init__(self, txn, name, database_engine, after_callbacks,
-                 final_callbacks):
+                 exception_callbacks):
         object.__setattr__(self, "txn", txn)
         object.__setattr__(self, "name", name)
         object.__setattr__(self, "database_engine", database_engine)
         object.__setattr__(self, "after_callbacks", after_callbacks)
-        object.__setattr__(self, "final_callbacks", final_callbacks)
+        object.__setattr__(self, "exception_callbacks", exception_callbacks)
 
     def call_after(self, callback, *args, **kwargs):
         """Call the given callback on the main twisted thread after the
@@ -66,8 +66,8 @@ class LoggingTransaction(object):
         """
         self.after_callbacks.append((callback, args, kwargs))
 
-    def call_finally(self, callback, *args, **kwargs):
-        self.final_callbacks.append((callback, args, kwargs))
+    def call_on_exception(self, callback, *args, **kwargs):
+        self.exception_callbacks.append((callback, args, kwargs))
 
     def __getattr__(self, name):
         return getattr(self.txn, name)
@@ -215,7 +215,7 @@ class SQLBaseStore(object):
 
         self._clock.looping_call(loop, 10000)
 
-    def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
+    def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
                          logging_context, func, *args, **kwargs):
         start = time.time() * 1000
         txn_id = self._TXN_ID
@@ -236,7 +236,7 @@ class SQLBaseStore(object):
                     txn = conn.cursor()
                     txn = LoggingTransaction(
                         txn, name, self.database_engine, after_callbacks,
-                        final_callbacks,
+                        exception_callbacks,
                     )
                     r = func(txn, *args, **kwargs)
                     conn.commit()
@@ -291,52 +291,66 @@ class SQLBaseStore(object):
 
     @defer.inlineCallbacks
     def runInteraction(self, desc, func, *args, **kwargs):
-        """Wraps the .runInteraction() method on the underlying db_pool."""
-        current_context = LoggingContext.current_context()
+        """Starts a transaction on the database and runs a given function
 
-        start_time = time.time() * 1000
+        Arguments:
+            desc (str): description of the transaction, for logging and metrics
+            func (func): callback function, which will be called with a
+                database transaction (twisted.enterprise.adbapi.Transaction) as
+                its first argument, followed by `args` and `kwargs`.
 
-        after_callbacks = []
-        final_callbacks = []
+            args (list): positional args to pass to `func`
+            kwargs (dict): named args to pass to `func`
 
-        def inner_func(conn, *args, **kwargs):
-            with LoggingContext("runInteraction") as context:
-                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+        Returns:
+            Deferred: The result of func
+        """
+        current_context = LoggingContext.current_context()
 
-                if self.database_engine.is_connection_closed(conn):
-                    logger.debug("Reconnecting closed database connection")
-                    conn.reconnect()
+        after_callbacks = []
+        exception_callbacks = []
 
-                current_context.copy_to(context)
-                return self._new_transaction(
-                    conn, desc, after_callbacks, final_callbacks, current_context,
-                    func, *args, **kwargs
-                )
+        def inner_func(conn, *args, **kwargs):
+            return self._new_transaction(
+                conn, desc, after_callbacks, exception_callbacks, current_context,
+                func, *args, **kwargs
+            )
 
         try:
-            with PreserveLoggingContext():
-                result = yield self._db_pool.runWithConnection(
-                    inner_func, *args, **kwargs
-                )
+            result = yield self.runWithConnection(inner_func, *args, **kwargs)
 
             for after_callback, after_args, after_kwargs in after_callbacks:
                 after_callback(*after_args, **after_kwargs)
-        finally:
-            for after_callback, after_args, after_kwargs in final_callbacks:
+        except:  # noqa: E722, as we reraise the exception this is fine.
+            for after_callback, after_args, after_kwargs in exception_callbacks:
                 after_callback(*after_args, **after_kwargs)
+            raise
 
         defer.returnValue(result)
 
     @defer.inlineCallbacks
     def runWithConnection(self, func, *args, **kwargs):
-        """Wraps the .runInteraction() method on the underlying db_pool."""
+        """Wraps the .runWithConnection() method on the underlying db_pool.
+
+        Arguments:
+            func (func): callback function, which will be called with a
+                database connection (twisted.enterprise.adbapi.Connection) as
+                its first argument, followed by `args` and `kwargs`.
+            args (list): positional args to pass to `func`
+            kwargs (dict): named args to pass to `func`
+
+        Returns:
+            Deferred: The result of func
+        """
         current_context = LoggingContext.current_context()
 
         start_time = time.time() * 1000
 
         def inner_func(conn, *args, **kwargs):
             with LoggingContext("runWithConnection") as context:
-                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+                sched_duration_ms = time.time() * 1000 - start_time
+                sql_scheduling_timer.inc_by(sched_duration_ms)
+                current_context.add_database_scheduled(sched_duration_ms)
 
                 if self.database_engine.is_connection_closed(conn):
                     logger.debug("Reconnecting closed database connection")
@@ -987,7 +1001,8 @@ class SQLBaseStore(object):
             # __exit__ called after the transaction finishes.
             ctx = self._cache_id_gen.get_next()
             stream_id = ctx.__enter__()
-            txn.call_finally(ctx.__exit__, None, None, None)
+            txn.call_on_exception(ctx.__exit__, None, None, None)
+            txn.call_after(ctx.__exit__, None, None, None)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
             self._simple_insert_txn(