summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/crypto/keyclient.py17
-rw-r--r--synapse/federation/federation_server.py46
-rw-r--r--synapse/handlers/_base.py11
-rw-r--r--synapse/handlers/events.py8
-rw-r--r--synapse/handlers/federation.py29
-rw-r--r--synapse/handlers/presence.py44
-rw-r--r--synapse/handlers/profile.py16
-rw-r--r--synapse/handlers/typing.py4
-rw-r--r--synapse/http/client.py6
-rw-r--r--synapse/http/matrixfederationclient.py32
-rw-r--r--synapse/http/server.py6
-rw-r--r--synapse/notifier.py19
-rw-r--r--synapse/storage/_base.py11
-rw-r--r--synapse/util/__init__.py11
-rw-r--r--synapse/util/async.py6
-rw-r--r--synapse/util/distributor.py45
-rw-r--r--synapse/util/logcontext.py52
17 files changed, 212 insertions, 151 deletions
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 4911f0896b..24f15f3154 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -18,7 +18,9 @@ from twisted.web.http import HTTPClient
 from twisted.internet.protocol import Factory
 from twisted.internet import defer, reactor
 from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import (
+    preserve_context_over_fn, preserve_context_over_deferred
+)
 import simplejson as json
 import logging
 
@@ -40,11 +42,14 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
 
     for i in range(5):
         try:
-            with PreserveLoggingContext():
-                protocol = yield endpoint.connect(factory)
-                server_response, server_certificate = yield protocol.remote_key
-                defer.returnValue((server_response, server_certificate))
-                return
+            protocol = yield preserve_context_over_fn(
+                endpoint.connect, factory
+            )
+            server_response, server_certificate = yield preserve_context_over_deferred(
+                protocol.remote_key
+            )
+            defer.returnValue((server_response, server_certificate))
+            return
         except SynapseKeyClientError as e:
             logger.exception("Error getting key for %r" % (server_name,))
             if e.status.startswith("4"):
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2b46188c91..cd79e23f4b 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -20,7 +20,6 @@ from .federation_base import FederationBase
 from .units import Transaction, Edu
 
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
 from synapse.events import FrozenEvent
 import synapse.metrics
 
@@ -123,29 +122,28 @@ class FederationServer(FederationBase):
 
         logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
-        with PreserveLoggingContext():
-            results = []
-
-            for pdu in pdu_list:
-                d = self._handle_new_pdu(transaction.origin, pdu)
-
-                try:
-                    yield d
-                    results.append({})
-                except FederationError as e:
-                    self.send_failure(e, transaction.origin)
-                    results.append({"error": str(e)})
-                except Exception as e:
-                    results.append({"error": str(e)})
-                    logger.exception("Failed to handle PDU")
-
-            if hasattr(transaction, "edus"):
-                for edu in [Edu(**x) for x in transaction.edus]:
-                    self.received_edu(
-                        transaction.origin,
-                        edu.edu_type,
-                        edu.content
-                    )
+        results = []
+
+        for pdu in pdu_list:
+            d = self._handle_new_pdu(transaction.origin, pdu)
+
+            try:
+                yield d
+                results.append({})
+            except FederationError as e:
+                self.send_failure(e, transaction.origin)
+                results.append({"error": str(e)})
+            except Exception as e:
+                results.append({"error": str(e)})
+                logger.exception("Failed to handle PDU")
+
+        if hasattr(transaction, "edus"):
+            for edu in [Edu(**x) for x in transaction.edus]:
+                self.received_edu(
+                    transaction.origin,
+                    edu.edu_type,
+                    edu.content
+                )
 
             for failure in getattr(transaction, "pdu_failures", []):
                 logger.info("Got failure %r", failure)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 4b3f4eadab..ddc5c21e7d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -20,6 +20,8 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.api.constants import Membership, EventTypes
 from synapse.types import UserID
 
+from synapse.util.logcontext import PreserveLoggingContext
+
 import logging
 
 
@@ -137,10 +139,11 @@ class BaseHandler(object):
                     "Failed to get destination from event %s", s.event_id
                 )
 
-        # Don't block waiting on waking up all the listeners.
-        notify_d = self.notifier.on_new_room_event(
-            event, extra_users=extra_users
-        )
+        with PreserveLoggingContext():
+            # Don't block waiting on waking up all the listeners.
+            notify_d = self.notifier.on_new_room_event(
+                event, extra_users=extra_users
+            )
 
         def log_failure(f):
             logger.warn(
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index f9f855213b..993d33ba47 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -15,7 +15,6 @@
 
 from twisted.internet import defer
 
-from synapse.util.logcontext import PreserveLoggingContext
 from synapse.util.logutils import log_function
 from synapse.types import UserID
 from synapse.events.utils import serialize_event
@@ -81,10 +80,9 @@ class EventStreamHandler(BaseHandler):
                 # thundering herds on restart.
                 timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
 
-            with PreserveLoggingContext():
-                events, tokens = yield self.notifier.get_events_for(
-                    auth_user, room_ids, pagin_config, timeout
-                )
+            events, tokens = yield self.notifier.get_events_for(
+                auth_user, room_ids, pagin_config, timeout
+            )
 
             time_now = self.clock.time_msec()
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 85e2757227..77c315c47c 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -21,6 +21,7 @@ from synapse.api.errors import (
     AuthError, FederationError, StoreError,
 )
 from synapse.api.constants import EventTypes, Membership, RejectedReason
+from synapse.util.logcontext import PreserveLoggingContext
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
 from synapse.util.frozenutils import unfreeze
@@ -197,9 +198,10 @@ class FederationHandler(BaseHandler):
                 target_user = UserID.from_string(target_user_id)
                 extra_users.append(target_user)
 
-            d = self.notifier.on_new_room_event(
-                event, extra_users=extra_users
-            )
+            with PreserveLoggingContext():
+                d = self.notifier.on_new_room_event(
+                    event, extra_users=extra_users
+                )
 
             def log_failure(f):
                 logger.warn(
@@ -431,9 +433,10 @@ class FederationHandler(BaseHandler):
                 auth_events=auth_events,
             )
 
-            d = self.notifier.on_new_room_event(
-                new_event, extra_users=[joinee]
-            )
+            with PreserveLoggingContext():
+                d = self.notifier.on_new_room_event(
+                    new_event, extra_users=[joinee]
+                )
 
             def log_failure(f):
                 logger.warn(
@@ -512,9 +515,10 @@ class FederationHandler(BaseHandler):
             target_user = UserID.from_string(target_user_id)
             extra_users.append(target_user)
 
-        d = self.notifier.on_new_room_event(
-            event, extra_users=extra_users
-        )
+        with PreserveLoggingContext():
+            d = self.notifier.on_new_room_event(
+                event, extra_users=extra_users
+            )
 
         def log_failure(f):
             logger.warn(
@@ -594,9 +598,10 @@ class FederationHandler(BaseHandler):
         )
 
         target_user = UserID.from_string(event.state_key)
-        d = self.notifier.on_new_room_event(
-            event, extra_users=[target_user],
-        )
+        with PreserveLoggingContext():
+            d = self.notifier.on_new_room_event(
+                event, extra_users=[target_user],
+            )
 
         def log_failure(f):
             logger.warn(
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 9e15610401..1edab05492 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -18,8 +18,8 @@ from twisted.internet import defer
 from synapse.api.errors import SynapseError, AuthError
 from synapse.api.constants import PresenceState
 
-from synapse.util.logutils import log_function
 from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logutils import log_function
 from synapse.types import UserID
 import synapse.metrics
 
@@ -278,15 +278,14 @@ class PresenceHandler(BaseHandler):
         now_online = state["presence"] != PresenceState.OFFLINE
         was_polling = target_user in self._user_cachemap
 
-        with PreserveLoggingContext():
-            if now_online and not was_polling:
-                self.start_polling_presence(target_user, state=state)
-            elif not now_online and was_polling:
-                self.stop_polling_presence(target_user)
+        if now_online and not was_polling:
+            self.start_polling_presence(target_user, state=state)
+        elif not now_online and was_polling:
+            self.stop_polling_presence(target_user)
 
-            # TODO(paul): perform a presence push as part of start/stop poll so
-            #   we don't have to do this all the time
-            self.changed_presencelike_data(target_user, state)
+        # TODO(paul): perform a presence push as part of start/stop poll so
+        #   we don't have to do this all the time
+        self.changed_presencelike_data(target_user, state)
 
     def bump_presence_active_time(self, user, now=None):
         if now is None:
@@ -408,10 +407,10 @@ class PresenceHandler(BaseHandler):
         yield self.store.set_presence_list_accepted(
             observer_user.localpart, observed_user.to_string()
         )
-        with PreserveLoggingContext():
-            self.start_polling_presence(
-                observer_user, target_user=observed_user
-            )
+
+        self.start_polling_presence(
+            observer_user, target_user=observed_user
+        )
 
     @defer.inlineCallbacks
     def deny_presence(self, observed_user, observer_user):
@@ -430,10 +429,9 @@ class PresenceHandler(BaseHandler):
             observer_user.localpart, observed_user.to_string()
         )
 
-        with PreserveLoggingContext():
-            self.stop_polling_presence(
-                observer_user, target_user=observed_user
-            )
+        self.stop_polling_presence(
+            observer_user, target_user=observed_user
+        )
 
     @defer.inlineCallbacks
     def get_presence_list(self, observer_user, accepted=None):
@@ -766,8 +764,7 @@ class PresenceHandler(BaseHandler):
                 if not self._remote_sendmap[user]:
                     del self._remote_sendmap[user]
 
-        with PreserveLoggingContext():
-            yield defer.DeferredList(deferreds, consumeErrors=True)
+        yield defer.DeferredList(deferreds, consumeErrors=True)
 
     @defer.inlineCallbacks
     def push_update_to_local_and_remote(self, observed_user, statuscache,
@@ -812,10 +809,11 @@ class PresenceHandler(BaseHandler):
 
     def push_update_to_clients(self, observed_user, users_to_push=[],
                                room_ids=[], statuscache=None):
-        self.notifier.on_new_user_event(
-            users_to_push,
-            room_ids,
-        )
+        with PreserveLoggingContext():
+            self.notifier.on_new_user_event(
+                users_to_push,
+                room_ids,
+            )
 
 
 class PresenceEventSource(object):
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index ee2732b848..ffb449d457 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -17,7 +17,6 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, AuthError, CodeMessageException
 from synapse.api.constants import EventTypes, Membership
-from synapse.util.logcontext import PreserveLoggingContext
 from synapse.types import UserID
 
 from ._base import BaseHandler
@@ -154,14 +153,13 @@ class ProfileHandler(BaseHandler):
         if not self.hs.is_mine(user):
             defer.returnValue(None)
 
-        with PreserveLoggingContext():
-            (displayname, avatar_url) = yield defer.gatherResults(
-                [
-                    self.store.get_profile_displayname(user.localpart),
-                    self.store.get_profile_avatar_url(user.localpart),
-                ],
-                consumeErrors=True
-            )
+        (displayname, avatar_url) = yield defer.gatherResults(
+            [
+                self.store.get_profile_displayname(user.localpart),
+                self.store.get_profile_avatar_url(user.localpart),
+            ],
+            consumeErrors=True
+        )
 
         state["displayname"] = displayname
         state["avatar_url"] = avatar_url
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index c0b2bd7db0..64fe51aa3e 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
 from ._base import BaseHandler
 
 from synapse.api.errors import SynapseError, AuthError
+from synapse.util.logcontext import PreserveLoggingContext
 from synapse.types import UserID
 
 import logging
@@ -216,7 +217,8 @@ class TypingNotificationHandler(BaseHandler):
         self._latest_room_serial += 1
         self._room_serials[room_id] = self._latest_room_serial
 
-        self.notifier.on_new_user_event(rooms=[room_id])
+        with PreserveLoggingContext():
+            self.notifier.on_new_user_event(rooms=[room_id])
 
 
 class TypingNotificationEventSource(object):
diff --git a/synapse/http/client.py b/synapse/http/client.py
index e8a5dedab4..5b3cefb2dc 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from synapse.api.errors import CodeMessageException
+from synapse.util.logcontext import preserve_context_over_fn
 from syutil.jsonutil import encode_canonical_json
 import synapse.metrics
 
@@ -61,7 +62,10 @@ class SimpleHttpClient(object):
         # A small wrapper around self.agent.request() so we can easily attach
         # counters to it
         outgoing_requests_counter.inc(method)
-        d = self.agent.request(method, *args, **kwargs)
+        d = preserve_context_over_fn(
+            self.agent.request,
+            method, *args, **kwargs
+        )
 
         def _cb(response):
             incoming_responses_counter.inc(method, response.code)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 7fa295cad5..c99d237c73 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -22,7 +22,7 @@ from twisted.web._newclient import ResponseDone
 
 from synapse.http.endpoint import matrix_federation_endpoint
 from synapse.util.async import sleep
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import preserve_context_over_fn
 import synapse.metrics
 
 from syutil.jsonutil import encode_canonical_json
@@ -144,22 +144,22 @@ class MatrixFederationHttpClient(object):
                 producer = body_callback(method, url_bytes, headers_dict)
 
             try:
-                with PreserveLoggingContext():
-                    request_deferred = self.agent.request(
-                        destination,
-                        endpoint,
-                        method,
-                        path_bytes,
-                        param_bytes,
-                        query_bytes,
-                        Headers(headers_dict),
-                        producer
-                    )
+                request_deferred = preserve_context_over_fn(
+                    self.agent.request,
+                    destination,
+                    endpoint,
+                    method,
+                    path_bytes,
+                    param_bytes,
+                    query_bytes,
+                    Headers(headers_dict),
+                    producer
+                )
 
-                    response = yield self.clock.time_bound_deferred(
-                        request_deferred,
-                        time_out=60,
-                    )
+                response = yield self.clock.time_bound_deferred(
+                    request_deferred,
+                    time_out=60,
+                )
 
                 logger.debug("Got response to %s", method)
                 break
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 93ecbd7589..73efbff4f2 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -17,7 +17,7 @@
 from synapse.api.errors import (
     cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
 )
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 import synapse.metrics
 
 from syutil.jsonutil import (
@@ -85,7 +85,9 @@ def request_handler(request_handler):
                     "Received request: %s %s",
                     request.method, request.path
                 )
-                yield request_handler(self, request)
+                d = request_handler(self, request)
+                with PreserveLoggingContext():
+                    yield d
                 code = request.code
             except CodeMessageException as e:
                 code = e.code
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 78eb28e4b2..7282dfd7f3 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -16,7 +16,6 @@
 from twisted.internet import defer
 
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
 from synapse.types import StreamToken
 import synapse.metrics
 
@@ -223,11 +222,10 @@ class Notifier(object):
         def eb(failure):
             logger.exception("Failed to notify listener", failure)
 
-        with PreserveLoggingContext():
-            yield defer.DeferredList(
-                [notify(l).addErrback(eb) for l in listeners],
-                consumeErrors=True,
-            )
+        yield defer.DeferredList(
+            [notify(l).addErrback(eb) for l in listeners],
+            consumeErrors=True,
+        )
 
     @defer.inlineCallbacks
     @log_function
@@ -298,11 +296,10 @@ class Notifier(object):
                     failure.getTracebackObject())
             )
 
-        with PreserveLoggingContext():
-            yield defer.DeferredList(
-                [notify(l).addErrback(eb) for l in listeners],
-                consumeErrors=True,
-            )
+        yield defer.DeferredList(
+            [notify(l).addErrback(eb) for l in listeners],
+            consumeErrors=True,
+        )
 
     @defer.inlineCallbacks
     def wait_for_events(self, user, rooms, filter, timeout, callback):
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 9e348590ba..c9fe5a3555 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -18,7 +18,7 @@ from synapse.api.errors import StoreError
 from synapse.events import FrozenEvent
 from synapse.events.utils import prune_event
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
+from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
 from synapse.util.lrucache import LruCache
 import synapse.metrics
 
@@ -420,10 +420,11 @@ class SQLBaseStore(object):
                     self._txn_perf_counters.update(desc, start, end)
                     sql_txn_timer.inc_by(duration, desc)
 
-        with PreserveLoggingContext():
-            result = yield self._db_pool.runWithConnection(
-                inner_func, *args, **kwargs
-            )
+        result = yield preserve_context_over_fn(
+            self._db_pool.runWithConnection,
+            inner_func, *args, **kwargs
+        )
+
         for after_callback, after_args in after_callbacks:
             after_callback(*after_args)
         defer.returnValue(result)
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 79109d0b19..fd3eb1f574 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 
 from twisted.internet import defer, reactor, task
 
@@ -50,9 +50,12 @@ class Clock(object):
         current_context = LoggingContext.current_context()
 
         def wrapped_callback():
-            LoggingContext.thread_local.current_context = current_context
-            callback()
-        return reactor.callLater(delay, wrapped_callback)
+            with PreserveLoggingContext():
+                LoggingContext.thread_local.current_context = current_context
+                callback()
+
+        with PreserveLoggingContext():
+            return reactor.callLater(delay, wrapped_callback)
 
     def cancel_call_later(self, timer):
         timer.cancel()
diff --git a/synapse/util/async.py b/synapse/util/async.py
index d8febdb90c..f78395a431 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,15 +16,13 @@
 
 from twisted.internet import defer, reactor
 
-from .logcontext import PreserveLoggingContext
+from .logcontext import preserve_context_over_deferred
 
 
-@defer.inlineCallbacks
 def sleep(seconds):
     d = defer.Deferred()
     reactor.callLater(seconds, d.callback, seconds)
-    with PreserveLoggingContext():
-        yield d
+    return preserve_context_over_deferred(d)
 
 
 def run_on_reactor():
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 9d9c350397..5b150cb0e5 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.logcontext import PreserveLoggingContext
-
 from twisted.internet import defer
 
 import logging
@@ -93,7 +91,6 @@ class Signal(object):
         Each observer callable may return a Deferred."""
         self.observers.append(observer)
 
-    @defer.inlineCallbacks
     def fire(self, *args, **kwargs):
         """Invokes every callable in the observer list, passing in the args and
         kwargs. Exceptions thrown by observers are logged but ignored. It is
@@ -101,24 +98,24 @@ class Signal(object):
 
         Returns a Deferred that will complete when all the observers have
         completed."""
-        with PreserveLoggingContext():
-            deferreds = []
-            for observer in self.observers:
-                d = defer.maybeDeferred(observer, *args, **kwargs)
-
-                def eb(failure):
-                    logger.warning(
-                        "%s signal observer %s failed: %r",
-                        self.name, observer, failure,
-                        exc_info=(
-                            failure.type,
-                            failure.value,
-                            failure.getTracebackObject()))
-                    if not self.suppress_failures:
-                        failure.raiseException()
-                deferreds.append(d.addErrback(eb))
-            results = []
-            for deferred in deferreds:
-                result = yield deferred
-                results.append(result)
-        defer.returnValue(results)
+
+        def eb(failure):
+            logger.warning(
+                "%s signal observer %s failed: %r",
+                self.name, observer, failure,
+                exc_info=(
+                    failure.type,
+                    failure.value,
+                    failure.getTracebackObject()))
+            if not self.suppress_failures:
+                failure.raiseException()
+
+        deferreds = [
+            defer.maybeDeferred(observer, *args, **kwargs)
+            for observer in self.observers
+        ]
+
+        d = defer.gatherResults(deferreds, consumeErrors=True)
+        d.addErrback(eb)
+
+        return d
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index da7872e95d..a92d518b43 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
 import threading
 import logging
 
@@ -129,3 +131,53 @@ class PreserveLoggingContext(object):
     def __exit__(self, type, value, traceback):
         """Restores the current logging context"""
         LoggingContext.thread_local.current_context = self.current_context
+
+        if self.current_context is not LoggingContext.sentinel:
+            if self.current_context.parent_context is None:
+                logger.warn(
+                    "Restoring dead context: %s",
+                    self.current_context,
+                )
+
+
+def preserve_context_over_fn(fn, *args, **kwargs):
+    """Takes a function and invokes it with the given arguments, but removes
+    and restores the current logging context while doing so.
+
+    If the result is a deferred, call preserve_context_over_deferred before
+    returning it.
+    """
+    with PreserveLoggingContext():
+        res = fn(*args, **kwargs)
+
+    if isinstance(res, defer.Deferred):
+        return preserve_context_over_deferred(res)
+    else:
+        return res
+
+
+def preserve_context_over_deferred(deferred):
+    """Given a deferred wrap it such that any callbacks added later to it will
+    be invoked with the current context.
+    """
+    d = defer.Deferred()
+
+    current_context = LoggingContext.current_context()
+
+    def cb(res):
+        with PreserveLoggingContext():
+            LoggingContext.thread_local.current_context = current_context
+            res = d.callback(res)
+        return res
+
+    def eb(failure):
+        with PreserveLoggingContext():
+            LoggingContext.thread_local.current_context = current_context
+            res = d.errback(failure)
+        return res
+
+    if deferred.called:
+        return deferred
+
+    deferred.addCallbacks(cb, eb)
+    return d