1 files changed, 52 insertions, 0 deletions
| diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index da7872e95d..a92d518b43 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
 import threading
 import logging
 
@@ -129,3 +131,53 @@ class PreserveLoggingContext(object):
     def __exit__(self, type, value, traceback):
         """Restores the current logging context"""
         LoggingContext.thread_local.current_context = self.current_context
+
+        if self.current_context is not LoggingContext.sentinel:
+            if self.current_context.parent_context is None:
+                logger.warn(
+                    "Restoring dead context: %s",
+                    self.current_context,
+                )
+
+
+def preserve_context_over_fn(fn, *args, **kwargs):
+    """Takes a function and invokes it with the given arguments, but removes
+    and restores the current logging context while doing so.
+
+    If the result is a deferred, call preserve_context_over_deferred before
+    returning it.
+    """
+    with PreserveLoggingContext():
+        res = fn(*args, **kwargs)
+
+    if isinstance(res, defer.Deferred):
+        return preserve_context_over_deferred(res)
+    else:
+        return res
+
+
+def preserve_context_over_deferred(deferred):
+    """Given a deferred wrap it such that any callbacks added later to it will
+    be invoked with the current context.
+    """
+    d = defer.Deferred()
+
+    current_context = LoggingContext.current_context()
+
+    def cb(res):
+        with PreserveLoggingContext():
+            LoggingContext.thread_local.current_context = current_context
+            res = d.callback(res)
+        return res
+
+    def eb(failure):
+        with PreserveLoggingContext():
+            LoggingContext.thread_local.current_context = current_context
+            res = d.errback(failure)
+        return res
+
+    if deferred.called:
+        return deferred
+
+    deferred.addCallbacks(cb, eb)
+    return d
 |