diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index c70eee649c..75217e3f45 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -171,6 +171,7 @@ from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
+ Awaitable,
Callable,
Collection,
ContextManager,
@@ -903,6 +904,7 @@ def _custom_sync_async_decorator(
"""
if inspect.iscoroutinefunction(func):
+ # For this branch, we handle async functions like `async def func() -> RInner`.
# In this branch, R = Awaitable[RInner], for some other type RInner
@wraps(func)
async def _wrapper(
@@ -914,15 +916,16 @@ def _custom_sync_async_decorator(
return await func(*args, **kwargs) # type: ignore[misc]
else:
- # The other case here handles both sync functions and those
- # decorated with inlineDeferred.
+ # The other case here handles sync functions including those decorated with
+ # `@defer.inlineCallbacks` or that return a `Deferred` or other `Awaitable`.
@wraps(func)
- def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+ def _wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
scope = wrapping_logic(func, *args, **kwargs)
scope.__enter__()
try:
result = func(*args, **kwargs)
+
if isinstance(result, defer.Deferred):
def call_back(result: R) -> R:
@@ -930,20 +933,32 @@ def _custom_sync_async_decorator(
return result
def err_back(result: R) -> R:
+ # TODO: Pass the error details into `scope.__exit__(...)` for
+ # consistency with the other paths.
scope.__exit__(None, None, None)
return result
result.addCallbacks(call_back, err_back)
+ elif inspect.isawaitable(result):
+
+ async def wrap_awaitable() -> Any:
+ try:
+ assert isinstance(result, Awaitable)
+ awaited_result = await result
+ scope.__exit__(None, None, None)
+ return awaited_result
+ except Exception as e:
+ scope.__exit__(type(e), None, e.__traceback__)
+ raise
+
+ # The original method returned an awaitable, eg. a coroutine, so we
+ # create another awaitable wrapping it that calls
+ # `scope.__exit__(...)`.
+ return wrap_awaitable()
else:
- if inspect.isawaitable(result):
- logger.error(
- "@trace may not have wrapped %s correctly! "
- "The function is not async but returned a %s.",
- func.__qualname__,
- type(result).__name__,
- )
-
+ # Just a simple sync function so we can just exit the scope and
+ # return the result without any fuss.
scope.__exit__(None, None, None)
return result
|