diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index d73670f9f2..7cbe390b15 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -308,47 +308,44 @@ def preserve_context_over_deferred(deferred, context=None):
return d
-def reset_context_after_deferred(deferred):
- """If the deferred is incomplete, add a callback which will reset the
- context.
-
- This is useful when you want to fire off a deferred, but don't want to
- wait for it to complete. (The deferred will restore the current log context
- when it completes, so if you don't do anything, it will leak log context.)
-
- (If this feels asymmetric, consider it this way: we are effectively forking
- a new thread of execution. We are probably currently within a
- ``with LoggingContext()`` block, which is supposed to have a single entry
- and exit point. But by spawning off another deferred, we are effectively
- adding a new exit point.)
+def preserve_fn(f):
+ """Wraps a function, to ensure that the current context is restored after
+ return from the function, and that the sentinel context is set once the
+ deferred returned by the funtion completes.
- Args:
- deferred (defer.Deferred): deferred
+ Useful for wrapping functions that return a deferred which you don't yield
+ on.
"""
def reset_context(result):
LoggingContext.set_current_context(LoggingContext.sentinel)
return result
- if not deferred.called:
- deferred.addBoth(reset_context)
-
-
-def preserve_fn(f):
- """Ensures that function is called with correct context and that context is
- restored after return. Useful for wrapping functions that return a deferred
- which you don't yield on.
- """
+ # XXX: why is this here rather than inside g? surely we want to preserve
+ # the context from the time the function was called, not when it was
+ # wrapped?
current = LoggingContext.current_context()
def g(*args, **kwargs):
- with PreserveLoggingContext(current):
- res = f(*args, **kwargs)
- if isinstance(res, defer.Deferred):
- return preserve_context_over_deferred(
- res, context=LoggingContext.sentinel
- )
- else:
- return res
+ res = f(*args, **kwargs)
+ if isinstance(res, defer.Deferred) and not res.called:
+ # The function will have reset the context before returning, so
+ # we need to restore it now.
+ LoggingContext.set_current_context(current)
+
+ # The original context will be restored when the deferred
+ # completes, but there is nothing waiting for it, so it will
+ # get leaked into the reactor or some other function which
+ # wasn't expecting it. We therefore need to reset the context
+ # here.
+ #
+ # (If this feels asymmetric, consider it this way: we are
+ # effectively forking a new thread of execution. We are
+ # probably currently within a ``with LoggingContext()`` block,
+ # which is supposed to have a single entry and exit point. But
+ # by spawning off another deferred, we are effectively
+ # adding a new exit point.)
+ res.addBoth(reset_context)
+ return res
return g
|