| diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 1135a683af..ff67b1d794 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -318,47 +318,44 @@ def preserve_context_over_deferred(deferred, context=None):
     return d
 
 
-def reset_context_after_deferred(deferred):
-    """If the deferred is incomplete, add a callback which will reset the
-    context.
-
-    This is useful when you want to fire off a deferred, but don't want to
-    wait for it to complete. (The deferred will restore the current log context
-    when it completes, so if you don't do anything, it will leak log context.)
-
-    (If this feels asymmetric, consider it this way: we are effectively forking
-    a new thread of execution. We are probably currently within a
-    ``with LoggingContext()`` block, which is supposed to have a single entry
-    and exit point. But by spawning off another deferred, we are effectively
-    adding a new exit point.)
+def preserve_fn(f):
+    """Wraps a function, to ensure that the current context is restored after
+    return from the function, and that the sentinel context is set once the
+    deferred returned by the funtion completes.
 
-    Args:
-        deferred (defer.Deferred): deferred
+    Useful for wrapping functions that return a deferred which you don't yield
+    on.
     """
     def reset_context(result):
         LoggingContext.set_current_context(LoggingContext.sentinel)
         return result
 
-    if not deferred.called:
-        deferred.addBoth(reset_context)
-
-
-def preserve_fn(f):
-    """Ensures that function is called with correct context and that context is
-    restored after return. Useful for wrapping functions that return a deferred
-    which you don't yield on.
-    """
+    # XXX: why is this here rather than inside g? surely we want to preserve
+    # the context from the time the function was called, not when it was
+    # wrapped?
     current = LoggingContext.current_context()
 
     def g(*args, **kwargs):
-        with PreserveLoggingContext(current):
-            res = f(*args, **kwargs)
-            if isinstance(res, defer.Deferred):
-                return preserve_context_over_deferred(
-                    res, context=LoggingContext.sentinel
-                )
-            else:
-                return res
+        res = f(*args, **kwargs)
+        if isinstance(res, defer.Deferred) and not res.called:
+            # The function will have reset the context before returning, so
+            # we need to restore it now.
+            LoggingContext.set_current_context(current)
+
+            # The original context will be restored when the deferred
+            # completes, but there is nothing waiting for it, so it will
+            # get leaked into the reactor or some other function which
+            # wasn't expecting it. We therefore need to reset the context
+            # here.
+            #
+            # (If this feels asymmetric, consider it this way: we are
+            # effectively forking a new thread of execution. We are
+            # probably currently within a ``with LoggingContext()`` block,
+            # which is supposed to have a single entry and exit point. But
+            # by spawning off another deferred, we are effectively
+            # adding a new exit point.)
+            res.addBoth(reset_context)
+        return res
     return g
 |