1 files changed, 52 insertions, 0 deletions
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index da7872e95d..a92d518b43 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
import threading
import logging
@@ -129,3 +131,53 @@ class PreserveLoggingContext(object):
def __exit__(self, type, value, traceback):
"""Restores the current logging context"""
LoggingContext.thread_local.current_context = self.current_context
+
+ if self.current_context is not LoggingContext.sentinel:
+ if self.current_context.parent_context is None:
+ logger.warn(
+ "Restoring dead context: %s",
+ self.current_context,
+ )
+
+
+def preserve_context_over_fn(fn, *args, **kwargs):
+ """Takes a function and invokes it with the given arguments, but removes
+ and restores the current logging context while doing so.
+
+ If the result is a deferred, call preserve_context_over_deferred before
+ returning it.
+ """
+ with PreserveLoggingContext():
+ res = fn(*args, **kwargs)
+
+ if isinstance(res, defer.Deferred):
+ return preserve_context_over_deferred(res)
+ else:
+ return res
+
+
+def preserve_context_over_deferred(deferred):
+ """Given a deferred wrap it such that any callbacks added later to it will
+ be invoked with the current context.
+ """
+ d = defer.Deferred()
+
+ current_context = LoggingContext.current_context()
+
+ def cb(res):
+ with PreserveLoggingContext():
+ LoggingContext.thread_local.current_context = current_context
+ res = d.callback(res)
+ return res
+
+ def eb(failure):
+ with PreserveLoggingContext():
+ LoggingContext.thread_local.current_context = current_context
+ res = d.errback(failure)
+ return res
+
+ if deferred.called:
+ return deferred
+
+ deferred.addCallbacks(cb, eb)
+ return d
|