summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13453.misc1
-rw-r--r--synapse/logging/opentracing.py158
-rw-r--r--tests/logging/test_opentracing.py83
3 files changed, 186 insertions, 56 deletions
diff --git a/changelog.d/13453.misc b/changelog.d/13453.misc
new file mode 100644
index 0000000000..d30c5230c8
--- /dev/null
+++ b/changelog.d/13453.misc
@@ -0,0 +1 @@
+Allow use of both `@trace` and `@tag_args` stacked on the same function (tracing).
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index fa3f76c27f..d1fa2cf8ae 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -173,6 +173,7 @@ from typing import (
     Any,
     Callable,
     Collection,
+    ContextManager,
     Dict,
     Generator,
     Iterable,
@@ -823,75 +824,117 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
 # Tracing decorators
 
 
-def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
+def _custom_sync_async_decorator(
+    func: Callable[P, R],
+    wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]],
+) -> Callable[P, R]:
     """
-    Decorator to trace a function with a custom opname.
-
-    See the module's doc string for usage examples.
+    Decorates a function that is sync or async (coroutines), or that returns a Twisted
+    `Deferred`. The custom business logic of the decorator goes in `wrapping_logic`.
+
+    Example usage:
+    ```py
+    # Decorator to time the function and log it out
+    def duration(func: Callable[P, R]) -> Callable[P, R]:
+        @contextlib.contextmanager
+        def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Generator[None, None, None]:
+            start_ts = time.time()
+            try:
+                yield
+            finally:
+                end_ts = time.time()
+                duration = end_ts - start_ts
+                logger.info("%s took %s seconds", func.__name__, duration)
+        return _custom_sync_async_decorator(func, _wrapping_logic)
+    ```
 
+    Args:
+        func: The function to be decorated
+        wrapping_logic: The business logic of your custom decorator.
+            This should be a ContextManager so you are able to run your logic
+            before/after the function as desired.
     """
 
-    def decorator(func: Callable[P, R]) -> Callable[P, R]:
-        if opentracing is None:
-            return func  # type: ignore[unreachable]
+    if inspect.iscoroutinefunction(func):
 
-        if inspect.iscoroutinefunction(func):
+        @wraps(func)
+        async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            with wrapping_logic(func, *args, **kwargs):
+                return await func(*args, **kwargs)  # type: ignore[misc]
 
-            @wraps(func)
-            async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
-                with start_active_span(opname):
-                    return await func(*args, **kwargs)  # type: ignore[misc]
+    else:
+        # The other case here handles both sync functions and those
+        # decorated with inlineDeferred.
+        @wraps(func)
+        def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            scope = wrapping_logic(func, *args, **kwargs)
+            scope.__enter__()
 
-        else:
-            # The other case here handles both sync functions and those
-            # decorated with inlineDeferred.
-            @wraps(func)
-            def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
-                scope = start_active_span(opname)
-                scope.__enter__()
-
-                try:
-                    result = func(*args, **kwargs)
-                    if isinstance(result, defer.Deferred):
-
-                        def call_back(result: R) -> R:
-                            scope.__exit__(None, None, None)
-                            return result
-
-                        def err_back(result: R) -> R:
-                            scope.__exit__(None, None, None)
-                            return result
-
-                        result.addCallbacks(call_back, err_back)
-
-                    else:
-                        if inspect.isawaitable(result):
-                            logger.error(
-                                "@trace may not have wrapped %s correctly! "
-                                "The function is not async but returned a %s.",
-                                func.__qualname__,
-                                type(result).__name__,
-                            )
+            try:
+                result = func(*args, **kwargs)
+                if isinstance(result, defer.Deferred):
+
+                    def call_back(result: R) -> R:
+                        scope.__exit__(None, None, None)
+                        return result
 
+                    def err_back(result: R) -> R:
                         scope.__exit__(None, None, None)
+                        return result
+
+                    result.addCallbacks(call_back, err_back)
+
+                else:
+                    if inspect.isawaitable(result):
+                        logger.error(
+                            "@trace may not have wrapped %s correctly! "
+                            "The function is not async but returned a %s.",
+                            func.__qualname__,
+                            type(result).__name__,
+                        )
+
+                    scope.__exit__(None, None, None)
 
-                    return result
+                return result
 
-                except Exception as e:
-                    scope.__exit__(type(e), None, e.__traceback__)
-                    raise
+            except Exception as e:
+                scope.__exit__(type(e), None, e.__traceback__)
+                raise
 
-        return _trace_inner  # type: ignore[return-value]
+    return _wrapper  # type: ignore[return-value]
 
-    return decorator
+
+def trace_with_opname(
+    opname: str,
+    *,
+    tracer: Optional["opentracing.Tracer"] = None,
+) -> Callable[[Callable[P, R]], Callable[P, R]]:
+    """
+    Decorator to trace a function with a custom opname.
+    See the module's doc string for usage examples.
+    """
+
+    # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+    @contextlib.contextmanager  # type: ignore[arg-type]
+    def _wrapping_logic(
+        func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+    ) -> Generator[None, None, None]:
+        with start_active_span(opname, tracer=tracer):
+            yield
+
+    def _decorator(func: Callable[P, R]) -> Callable[P, R]:
+        if not opentracing:
+            return func
+
+        return _custom_sync_async_decorator(func, _wrapping_logic)
+
+    return _decorator
 
 
 def trace(func: Callable[P, R]) -> Callable[P, R]:
     """
     Decorator to trace a function.
-
     Sets the operation name to that of the function's name.
-
     See the module's doc string for usage examples.
     """
 
@@ -900,7 +943,7 @@ def trace(func: Callable[P, R]) -> Callable[P, R]:
 
 def tag_args(func: Callable[P, R]) -> Callable[P, R]:
     """
-    Tags all of the args to the active span.
+    Decorator to tag all of the args to the active span.
 
     Args:
         func: `func` is assumed to be a method taking a `self` parameter, or a
@@ -911,22 +954,25 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
     if not opentracing:
         return func
 
-    @wraps(func)
-    def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+    # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+    @contextlib.contextmanager  # type: ignore[arg-type]
+    def _wrapping_logic(
+        func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+    ) -> Generator[None, None, None]:
         argspec = inspect.getfullargspec(func)
         # We use `[1:]` to skip the `self` object reference and `start=1` to
         # make the index line up with `argspec.args`.
         #
-        # FIXME: We could update this handle any type of function by ignoring the
+        # FIXME: We could update this to handle any type of function by ignoring the
         #   first argument only if it's named `self` or `cls`. This isn't fool-proof
         #   but handles the idiomatic cases.
         for i, arg in enumerate(args[1:], start=1):  # type: ignore[index]
             set_tag("ARG_" + argspec.args[i], str(arg))
         set_tag("args", str(args[len(argspec.args) :]))  # type: ignore[index]
         set_tag("kwargs", str(kwargs))
-        return func(*args, **kwargs)
+        yield
 
-    return _tag_args_inner
+    return _custom_sync_async_decorator(func, _wrapping_logic)
 
 
 @contextlib.contextmanager
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
index 3b14c76d7e..0917e478a5 100644
--- a/tests/logging/test_opentracing.py
+++ b/tests/logging/test_opentracing.py
@@ -25,6 +25,8 @@ from synapse.logging.context import (
 from synapse.logging.opentracing import (
     start_active_span,
     start_active_span_follows_from,
+    tag_args,
+    trace_with_opname,
 )
 from synapse.util import Clock
 
@@ -38,8 +40,12 @@ try:
 except ImportError:
     jaeger_client = None  # type: ignore
 
+import logging
+
 from tests.unittest import TestCase
 
+logger = logging.getLogger(__name__)
+
 
 class LogContextScopeManagerTestCase(TestCase):
     """
@@ -194,3 +200,80 @@ class LogContextScopeManagerTestCase(TestCase):
             self._reporter.get_spans(),
             [scopes[1].span, scopes[2].span, scopes[0].span],
         )
+
+    def test_trace_decorator_sync(self) -> None:
+        """
+        Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+        with sync functions
+        """
+        with LoggingContext("root context"):
+
+            @trace_with_opname("fixture_sync_func", tracer=self._tracer)
+            @tag_args
+            def fixture_sync_func() -> str:
+                return "foo"
+
+            result = fixture_sync_func()
+            self.assertEqual(result, "foo")
+
+        # the span should have been reported
+        self.assertEqual(
+            [span.operation_name for span in self._reporter.get_spans()],
+            ["fixture_sync_func"],
+        )
+
+    def test_trace_decorator_deferred(self) -> None:
+        """
+        Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+        with functions that return deferreds
+        """
+        reactor = MemoryReactorClock()
+
+        with LoggingContext("root context"):
+
+            @trace_with_opname("fixture_deferred_func", tracer=self._tracer)
+            @tag_args
+            def fixture_deferred_func() -> "defer.Deferred[str]":
+                d1: defer.Deferred[str] = defer.Deferred()
+                d1.callback("foo")
+                return d1
+
+            result_d1 = fixture_deferred_func()
+
+            # let the tasks complete
+            reactor.pump((2,) * 8)
+
+            self.assertEqual(self.successResultOf(result_d1), "foo")
+
+        # the span should have been reported
+        self.assertEqual(
+            [span.operation_name for span in self._reporter.get_spans()],
+            ["fixture_deferred_func"],
+        )
+
+    def test_trace_decorator_async(self) -> None:
+        """
+        Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+        with async functions
+        """
+        reactor = MemoryReactorClock()
+
+        with LoggingContext("root context"):
+
+            @trace_with_opname("fixture_async_func", tracer=self._tracer)
+            @tag_args
+            async def fixture_async_func() -> str:
+                return "foo"
+
+            d1 = defer.ensureDeferred(fixture_async_func())
+
+            # let the tasks complete
+            reactor.pump((2,) * 8)
+
+            self.assertEqual(self.successResultOf(d1), "foo")
+
+        # the span should have been reported
+        self.assertEqual(
+            [span.operation_name for span in self._reporter.get_spans()],
+            ["fixture_async_func"],
+        )