diff --git a/changelog.d/15650.misc b/changelog.d/15650.misc
new file mode 100644
index 0000000000..9bbad113e1
--- /dev/null
+++ b/changelog.d/15650.misc
@@ -0,0 +1 @@
+Add support for tracing functions which return `Awaitable`s.
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index c70eee649c..75217e3f45 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -171,6 +171,7 @@ from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
+ Awaitable,
Callable,
Collection,
ContextManager,
@@ -903,6 +904,7 @@ def _custom_sync_async_decorator(
"""
if inspect.iscoroutinefunction(func):
+ # For this branch, we handle async functions like `async def func() -> RInner`.
# In this branch, R = Awaitable[RInner], for some other type RInner
@wraps(func)
async def _wrapper(
@@ -914,15 +916,16 @@ def _custom_sync_async_decorator(
return await func(*args, **kwargs) # type: ignore[misc]
else:
- # The other case here handles both sync functions and those
- # decorated with inlineDeferred.
+ # The other case here handles sync functions including those decorated with
+ # `@defer.inlineCallbacks` or that return a `Deferred` or other `Awaitable`.
@wraps(func)
- def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+ def _wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
scope = wrapping_logic(func, *args, **kwargs)
scope.__enter__()
try:
result = func(*args, **kwargs)
+
if isinstance(result, defer.Deferred):
def call_back(result: R) -> R:
@@ -930,20 +933,32 @@ def _custom_sync_async_decorator(
return result
def err_back(result: R) -> R:
+ # TODO: Pass the error details into `scope.__exit__(...)` for
+ # consistency with the other paths.
scope.__exit__(None, None, None)
return result
result.addCallbacks(call_back, err_back)
+ elif inspect.isawaitable(result):
+
+ async def wrap_awaitable() -> Any:
+ try:
+ assert isinstance(result, Awaitable)
+ awaited_result = await result
+ scope.__exit__(None, None, None)
+ return awaited_result
+ except Exception as e:
+ scope.__exit__(type(e), None, e.__traceback__)
+ raise
+
+ # The original method returned an awaitable, eg. a coroutine, so we
+ # create another awaitable wrapping it that calls
+ # `scope.__exit__(...)`.
+ return wrap_awaitable()
else:
- if inspect.isawaitable(result):
- logger.error(
- "@trace may not have wrapped %s correctly! "
- "The function is not async but returned a %s.",
- func.__qualname__,
- type(result).__name__,
- )
-
+ # Just a simple sync function so we can just exit the scope and
+ # return the result without any fuss.
scope.__exit__(None, None, None)
return result
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
index e28ba84cc2..1bc7d64ad9 100644
--- a/tests/logging/test_opentracing.py
+++ b/tests/logging/test_opentracing.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import cast
+from typing import Awaitable, cast
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactorClock
@@ -227,8 +227,6 @@ class LogContextScopeManagerTestCase(TestCase):
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with functions that return deferreds
"""
- reactor = MemoryReactorClock()
-
with LoggingContext("root context"):
@trace_with_opname("fixture_deferred_func", tracer=self._tracer)
@@ -240,9 +238,6 @@ class LogContextScopeManagerTestCase(TestCase):
result_d1 = fixture_deferred_func()
- # let the tasks complete
- reactor.pump((2,) * 8)
-
self.assertEqual(self.successResultOf(result_d1), "foo")
# the span should have been reported
@@ -256,8 +251,6 @@ class LogContextScopeManagerTestCase(TestCase):
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with async functions
"""
- reactor = MemoryReactorClock()
-
with LoggingContext("root context"):
@trace_with_opname("fixture_async_func", tracer=self._tracer)
@@ -267,9 +260,6 @@ class LogContextScopeManagerTestCase(TestCase):
d1 = defer.ensureDeferred(fixture_async_func())
- # let the tasks complete
- reactor.pump((2,) * 8)
-
self.assertEqual(self.successResultOf(d1), "foo")
# the span should have been reported
@@ -277,3 +267,34 @@ class LogContextScopeManagerTestCase(TestCase):
[span.operation_name for span in self._reporter.get_spans()],
["fixture_async_func"],
)
+
+ def test_trace_decorator_awaitable_return(self) -> None:
+ """
+ Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+ with functions that return an awaitable (e.g. a coroutine)
+ """
+ with LoggingContext("root context"):
+ # Something we can return without `await` to get a coroutine
+ async def fixture_async_func() -> str:
+ return "foo"
+
+ # The actual kind of function we want to test that returns an awaitable
+ @trace_with_opname("fixture_awaitable_return_func", tracer=self._tracer)
+ @tag_args
+ def fixture_awaitable_return_func() -> Awaitable[str]:
+ return fixture_async_func()
+
+ # Something we can run with `defer.ensureDeferred(runner())` and pump the
+ # whole async tasks through to completion.
+ async def runner() -> str:
+ return await fixture_awaitable_return_func()
+
+ d1 = defer.ensureDeferred(runner())
+
+ self.assertEqual(self.successResultOf(d1), "foo")
+
+ # the span should have been reported
+ self.assertEqual(
+ [span.operation_name for span in self._reporter.get_spans()],
+ ["fixture_awaitable_return_func"],
+ )
|