diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 670387b04a..30e6eac8db 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,8 +17,11 @@ import logging
from synapse.api.errors import StoreError
from synapse.api.events.utils import prune_event
from synapse.util.logutils import log_function
+from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
from syutil.base64util import encode_base64
+from twisted.internet import defer
+
import collections
import copy
import json
@@ -84,32 +87,40 @@ class SQLBaseStore(object):
self.event_factory = hs.get_event_factory()
self._clock = hs.get_clock()
+ @defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
+ current_context = LoggingContext.current_context()
def inner_func(txn, *args, **kwargs):
- start = time.clock() * 1000
- txn_id = SQLBaseStore._TXN_ID
-
- # We don't really need these to be unique, so lets stop it from
- # growing really large.
- self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
-
- name = "%s-%x" % (desc, txn_id, )
-
- transaction_logger.debug("[TXN START] {%s}", name)
- try:
- return func(LoggingTransaction(txn, name), *args, **kwargs)
- except:
- logger.exception("[TXN FAIL] {%s}", name)
- raise
- finally:
- end = time.clock() * 1000
- transaction_logger.debug(
- "[TXN END] {%s} %f",
- name, end - start
- )
+ with LoggingContext("runInteraction") as context:
+ current_context.copy_to(context)
+ start = time.clock() * 1000
+ txn_id = SQLBaseStore._TXN_ID
+
+ # We don't really need these to be unique, so lets stop it from
+ # growing really large.
+ self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+
+ name = "%s-%x" % (desc, txn_id, )
+
+ transaction_logger.debug("[TXN START] {%s}", name)
+ try:
+ return func(LoggingTransaction(txn, name), *args, **kwargs)
+ except:
+ logger.exception("[TXN FAIL] {%s}", name)
+ raise
+ finally:
+ end = time.clock() * 1000
+ transaction_logger.debug(
+ "[TXN END] {%s} %f",
+ name, end - start
+ )
- return self._db_pool.runInteraction(inner_func, *args, **kwargs)
+ with PreserveLoggingContext():
+ result = yield self._db_pool.runInteraction(
+ inner_func, *args, **kwargs
+ )
+ defer.returnValue(result)
def cursor_to_dict(self, cursor):
"""Converts a SQL cursor into an list of dicts.
@@ -177,7 +188,7 @@ class SQLBaseStore(object):
)
logger.debug(
- "[SQL] %s Args=%s Func=%s",
+ "[SQL] %s Args=%s",
sql, values.values(),
)
|