diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 79109d0b19..364b927851 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from twisted.internet import defer, reactor, task
@@ -50,8 +50,10 @@ class Clock(object):
current_context = LoggingContext.current_context()
def wrapped_callback():
- LoggingContext.thread_local.current_context = current_context
- callback()
+ with PreserveLoggingContext():
+ LoggingContext.thread_local.current_context = current_context
+ callback()
+
return reactor.callLater(delay, wrapped_callback)
def cancel_call_later(self, timer):
diff --git a/synapse/util/async.py b/synapse/util/async.py
index d8febdb90c..f78395a431 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,15 +16,13 @@
from twisted.internet import defer, reactor
-from .logcontext import PreserveLoggingContext
+from .logcontext import preserve_context_over_deferred
-@defer.inlineCallbacks
def sleep(seconds):
d = defer.Deferred()
reactor.callLater(seconds, d.callback, seconds)
- with PreserveLoggingContext():
- yield d
+ return preserve_context_over_deferred(d)
def run_on_reactor():
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index da7872e95d..192e3f49f0 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
import threading
import logging
@@ -129,3 +131,32 @@ class PreserveLoggingContext(object):
def __exit__(self, type, value, traceback):
"""Restores the current logging context"""
LoggingContext.thread_local.current_context = self.current_context
+
+
+def preserve_context_over_fn(fn, *args, **kwargs):
+ with PreserveLoggingContext():
+ deferred = fn(*args, **kwargs)
+
+ return preserve_context_over_deferred(deferred)
+
+
+def preserve_context_over_deferred(deferred):
+ d = defer.Deferred()
+
+ current_context = LoggingContext.current_context()
+
+ def cb(res):
+ with PreserveLoggingContext():
+ LoggingContext.thread_local.current_context = current_context
+ res = d.callback(res)
+ return res
+
+ def eb(failure):
+ with PreserveLoggingContext():
+ LoggingContext.thread_local.current_context = current_context
+ res = d.errback(failure)
+ return res
+
+ deferred.addCallbacks(cb, eb)
+
+ return d
|