summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py57
1 files changed, 34 insertions, 23 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 670387b04a..30e6eac8db 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,8 +17,11 @@ import logging
 from synapse.api.errors import StoreError
 from synapse.api.events.utils import prune_event
 from synapse.util.logutils import log_function
+from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
 from syutil.base64util import encode_base64
 
+from twisted.internet import defer
+
 import collections
 import copy
 import json
@@ -84,32 +87,40 @@ class SQLBaseStore(object):
         self.event_factory = hs.get_event_factory()
         self._clock = hs.get_clock()
 
+    @defer.inlineCallbacks
     def runInteraction(self, desc, func, *args, **kwargs):
         """Wraps the .runInteraction() method on the underlying db_pool."""
+        current_context = LoggingContext.current_context()
         def inner_func(txn, *args, **kwargs):
-            start = time.clock() * 1000
-            txn_id = SQLBaseStore._TXN_ID
-
-            # We don't really need these to be unique, so lets stop it from
-            # growing really large.
-            self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
-
-            name = "%s-%x" % (desc, txn_id, )
-
-            transaction_logger.debug("[TXN START] {%s}", name)
-            try:
-                return func(LoggingTransaction(txn, name), *args, **kwargs)
-            except:
-                logger.exception("[TXN FAIL] {%s}", name)
-                raise
-            finally:
-                end = time.clock() * 1000
-                transaction_logger.debug(
-                    "[TXN END] {%s} %f",
-                    name, end - start
-                )
+            with LoggingContext("runInteraction") as context:
+                current_context.copy_to(context)
+                start = time.clock() * 1000
+                txn_id = SQLBaseStore._TXN_ID
+
+                # We don't really need these to be unique, so lets stop it from
+                # growing really large.
+                self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+
+                name = "%s-%x" % (desc, txn_id, )
+
+                transaction_logger.debug("[TXN START] {%s}", name)
+                try:
+                    return func(LoggingTransaction(txn, name), *args, **kwargs)
+                except:
+                    logger.exception("[TXN FAIL] {%s}", name)
+                    raise
+                finally:
+                    end = time.clock() * 1000
+                    transaction_logger.debug(
+                        "[TXN END] {%s} %f",
+                        name, end - start
+                    )
 
-        return self._db_pool.runInteraction(inner_func, *args, **kwargs)
+        with PreserveLoggingContext():
+            result = yield self._db_pool.runInteraction(
+                inner_func, *args, **kwargs
+            )
+        defer.returnValue(result)
 
     def cursor_to_dict(self, cursor):
         """Converts a SQL cursor into an list of dicts.
@@ -177,7 +188,7 @@ class SQLBaseStore(object):
         )
 
         logger.debug(
-            "[SQL] %s  Args=%s Func=%s",
+            "[SQL] %s Args=%s",
             sql, values.values(),
         )