diff --git a/changelog.d/6505.misc b/changelog.d/6505.misc
new file mode 100644
index 0000000000..3a75b2d9dd
--- /dev/null
+++ b/changelog.d/6505.misc
@@ -0,0 +1 @@
+Make `make_deferred_yieldable` to work with async/await.
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 2c1fb9ddac..6747f29e6a 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -23,6 +23,7 @@ them.
See doc/log_contexts.rst for details on how this works.
"""
+import inspect
import logging
import threading
import types
@@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs):
def make_deferred_yieldable(deferred):
- """Given a deferred, make it follow the Synapse logcontext rules:
+ """Given a deferred (or coroutine), make it follow the Synapse logcontext
+ rules:
If the deferred has completed (or is not actually a Deferred), essentially
does nothing (just returns another completed deferred with the
@@ -624,6 +626,13 @@ def make_deferred_yieldable(deferred):
(This is more-or-less the opposite operation to run_in_background.)
"""
+ if inspect.isawaitable(deferred):
+ # If we're given a coroutine we convert it to a deferred so that we
+ # run it and find out if it immediately finishes, it it does then we
+ # don't need to fiddle with log contexts at all and can return
+ # immediately.
+ deferred = defer.ensureDeferred(deferred)
+
if not isinstance(deferred, defer.Deferred):
return deferred
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 8b8455c8b7..281b32c4b8 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -179,6 +179,30 @@ class LoggingContextTestCase(unittest.TestCase):
nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.request, "foo-bar")
+ @defer.inlineCallbacks
+ def test_make_deferred_yieldable_with_await(self):
+ # an async function which retuns an incomplete coroutine, but doesn't
+ # follow the synapse rules.
+
+ async def blocking_function():
+ d = defer.Deferred()
+ reactor.callLater(0, d.callback, None)
+ await d
+
+ sentinel_context = LoggingContext.current_context()
+
+ with LoggingContext() as context_one:
+ context_one.request = "one"
+
+ d1 = make_deferred_yieldable(blocking_function())
+ # make sure that the context was reset by make_deferred_yieldable
+ self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+ yield d1
+
+ # now it should be restored
+ self._check_test_key("one")
+
# a function which returns a deferred which has been "called", but
# which had a function which returned another incomplete deferred on
|