summary refs log tree commit diff
path: root/synapse/logging/opentracing.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/logging/opentracing.py')
-rw-r--r--synapse/logging/opentracing.py232
1 files changed, 164 insertions, 68 deletions
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 17e729f0c7..482316a1ff 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -173,6 +173,7 @@ from typing import (
     Any,
     Callable,
     Collection,
+    ContextManager,
     Dict,
     Generator,
     Iterable,
@@ -182,6 +183,8 @@ from typing import (
     Type,
     TypeVar,
     Union,
+    cast,
+    overload,
 )
 
 import attr
@@ -307,6 +310,19 @@ class SynapseTags:
     # The name of the external cache
     CACHE_NAME = "cache.name"
 
+    # Used to tag function arguments
+    #
+    # Tag a named arg. The name of the argument should be appended to this prefix.
+    FUNC_ARG_PREFIX = "ARG."
+    # Tag extra variadic number of positional arguments (`def foo(first, second, *extras)`)
+    FUNC_ARGS = "args"
+    # Tag keyword args
+    FUNC_KWARGS = "kwargs"
+
+    # Some intermediate result that's interesting to the function. The label for
+    # the result should be appended to this prefix.
+    RESULT_PREFIX = "RESULT."
+
 
 class SynapseBaggage:
     FORCE_TRACING = "synapse-force-tracing"
@@ -328,6 +344,7 @@ class _Sentinel(enum.Enum):
 
 P = ParamSpec("P")
 R = TypeVar("R")
+T = TypeVar("T")
 
 
 def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
@@ -343,22 +360,43 @@ def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
     return _only_if_tracing_inner
 
 
-def ensure_active_span(message: str, ret=None):
+@overload
+def ensure_active_span(
+    message: str,
+) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]:
+    ...
+
+
+@overload
+def ensure_active_span(
+    message: str, ret: T
+) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]:
+    ...
+
+
+def ensure_active_span(
+    message: str, ret: Optional[T] = None
+) -> Callable[[Callable[P, R]], Callable[P, Union[Optional[T], R]]]:
     """Executes the operation only if opentracing is enabled and there is an active span.
     If there is no active span it logs message at the error level.
 
     Args:
         message: Message which fills in "There was no active span when trying to %s"
             in the error log if there is no active span and opentracing is enabled.
-        ret (object): return value if opentracing is None or there is no active span.
+        ret: return value if opentracing is None or there is no active span.
 
-    Returns (object): The result of the func or ret if opentracing is disabled or there
+    Returns:
+        The result of the func, falling back to ret if opentracing is disabled or there
         was no active span.
     """
 
-    def ensure_active_span_inner_1(func):
+    def ensure_active_span_inner_1(
+        func: Callable[P, R]
+    ) -> Callable[P, Union[Optional[T], R]]:
         @wraps(func)
-        def ensure_active_span_inner_2(*args, **kwargs):
+        def ensure_active_span_inner_2(
+            *args: P.args, **kwargs: P.kwargs
+        ) -> Union[Optional[T], R]:
             if not opentracing:
                 return ret
 
@@ -464,7 +502,7 @@ def start_active_span(
     finish_on_close: bool = True,
     *,
     tracer: Optional["opentracing.Tracer"] = None,
-):
+) -> "opentracing.Scope":
     """Starts an active opentracing span.
 
     Records the start time for the span, and sets it as the "active span" in the
@@ -502,7 +540,7 @@ def start_active_span_follows_from(
     *,
     inherit_force_tracing: bool = False,
     tracer: Optional["opentracing.Tracer"] = None,
-):
+) -> "opentracing.Scope":
     """Starts an active opentracing span, with additional references to previous spans
 
     Args:
@@ -717,7 +755,9 @@ def inject_response_headers(response_headers: Headers) -> None:
         response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
 
 
-@ensure_active_span("get the active span context as a dict", ret={})
+@ensure_active_span(
+    "get the active span context as a dict", ret=cast(Dict[str, str], {})
+)
 def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
     """
     Gets a span context as a dict. This can be used instead of manually
@@ -797,75 +837,117 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
 # Tracing decorators
 
 
-def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
+def _custom_sync_async_decorator(
+    func: Callable[P, R],
+    wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]],
+) -> Callable[P, R]:
     """
-    Decorator to trace a function with a custom opname.
-
-    See the module's doc string for usage examples.
+    Decorates a function that is sync or async (coroutines), or that returns a Twisted
+    `Deferred`. The custom business logic of the decorator goes in `wrapping_logic`.
+
+    Example usage:
+    ```py
+    # Decorator to time the function and log it out
+    def duration(func: Callable[P, R]) -> Callable[P, R]:
+        @contextlib.contextmanager
+        def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Generator[None, None, None]:
+            start_ts = time.time()
+            try:
+                yield
+            finally:
+                end_ts = time.time()
+                duration = end_ts - start_ts
+                logger.info("%s took %s seconds", func.__name__, duration)
+        return _custom_sync_async_decorator(func, _wrapping_logic)
+    ```
 
+    Args:
+        func: The function to be decorated
+        wrapping_logic: The business logic of your custom decorator.
+            This should be a ContextManager so you are able to run your logic
+            before/after the function as desired.
     """
 
-    def decorator(func: Callable[P, R]) -> Callable[P, R]:
-        if opentracing is None:
-            return func  # type: ignore[unreachable]
+    if inspect.iscoroutinefunction(func):
 
-        if inspect.iscoroutinefunction(func):
+        @wraps(func)
+        async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            with wrapping_logic(func, *args, **kwargs):
+                return await func(*args, **kwargs)  # type: ignore[misc]
 
-            @wraps(func)
-            async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
-                with start_active_span(opname):
-                    return await func(*args, **kwargs)  # type: ignore[misc]
+    else:
+        # The other case here handles both sync functions and those
+        # decorated with inlineDeferred.
+        @wraps(func)
+        def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            scope = wrapping_logic(func, *args, **kwargs)
+            scope.__enter__()
 
-        else:
-            # The other case here handles both sync functions and those
-            # decorated with inlineDeferred.
-            @wraps(func)
-            def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
-                scope = start_active_span(opname)
-                scope.__enter__()
-
-                try:
-                    result = func(*args, **kwargs)
-                    if isinstance(result, defer.Deferred):
-
-                        def call_back(result: R) -> R:
-                            scope.__exit__(None, None, None)
-                            return result
-
-                        def err_back(result: R) -> R:
-                            scope.__exit__(None, None, None)
-                            return result
-
-                        result.addCallbacks(call_back, err_back)
-
-                    else:
-                        if inspect.isawaitable(result):
-                            logger.error(
-                                "@trace may not have wrapped %s correctly! "
-                                "The function is not async but returned a %s.",
-                                func.__qualname__,
-                                type(result).__name__,
-                            )
+            try:
+                result = func(*args, **kwargs)
+                if isinstance(result, defer.Deferred):
+
+                    def call_back(result: R) -> R:
+                        scope.__exit__(None, None, None)
+                        return result
 
+                    def err_back(result: R) -> R:
                         scope.__exit__(None, None, None)
+                        return result
+
+                    result.addCallbacks(call_back, err_back)
+
+                else:
+                    if inspect.isawaitable(result):
+                        logger.error(
+                            "@trace may not have wrapped %s correctly! "
+                            "The function is not async but returned a %s.",
+                            func.__qualname__,
+                            type(result).__name__,
+                        )
+
+                    scope.__exit__(None, None, None)
+
+                return result
+
+            except Exception as e:
+                scope.__exit__(type(e), None, e.__traceback__)
+                raise
+
+    return _wrapper  # type: ignore[return-value]
+
+
+def trace_with_opname(
+    opname: str,
+    *,
+    tracer: Optional["opentracing.Tracer"] = None,
+) -> Callable[[Callable[P, R]], Callable[P, R]]:
+    """
+    Decorator to trace a function with a custom opname.
+    See the module's doc string for usage examples.
+    """
 
-                    return result
+    # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+    @contextlib.contextmanager  # type: ignore[arg-type]
+    def _wrapping_logic(
+        func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+    ) -> Generator[None, None, None]:
+        with start_active_span(opname, tracer=tracer):
+            yield
 
-                except Exception as e:
-                    scope.__exit__(type(e), None, e.__traceback__)
-                    raise
+    def _decorator(func: Callable[P, R]) -> Callable[P, R]:
+        if not opentracing:
+            return func
 
-        return _trace_inner  # type: ignore[return-value]
+        return _custom_sync_async_decorator(func, _wrapping_logic)
 
-    return decorator
+    return _decorator
 
 
 def trace(func: Callable[P, R]) -> Callable[P, R]:
     """
     Decorator to trace a function.
-
     Sets the operation name to that of the function's name.
-
     See the module's doc string for usage examples.
     """
 
@@ -874,22 +956,36 @@ def trace(func: Callable[P, R]) -> Callable[P, R]:
 
 def tag_args(func: Callable[P, R]) -> Callable[P, R]:
     """
-    Tags all of the args to the active span.
+    Decorator to tag all of the args to the active span.
+
+    Args:
+        func: `func` is assumed to be a method taking a `self` parameter, or a
+            `classmethod` taking a `cls` parameter. In either case, a tag is not
+            created for this parameter.
     """
 
     if not opentracing:
         return func
 
-    @wraps(func)
-    def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+    # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+    @contextlib.contextmanager  # type: ignore[arg-type]
+    def _wrapping_logic(
+        func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+    ) -> Generator[None, None, None]:
         argspec = inspect.getfullargspec(func)
-        for i, arg in enumerate(argspec.args[1:]):
-            set_tag("ARG_" + arg, args[i])  # type: ignore[index]
-        set_tag("args", args[len(argspec.args) :])  # type: ignore[index]
-        set_tag("kwargs", kwargs)
-        return func(*args, **kwargs)
-
-    return _tag_args_inner
+        # We use `[1:]` to skip the `self` object reference and `start=1` to
+        # make the index line up with `argspec.args`.
+        #
+        # FIXME: We could update this to handle any type of function by ignoring the
+        #   first argument only if it's named `self` or `cls`. This isn't fool-proof
+        #   but handles the idiomatic cases.
+        for i, arg in enumerate(args[1:], start=1):  # type: ignore[index]
+            set_tag(SynapseTags.FUNC_ARG_PREFIX + argspec.args[i], str(arg))
+        set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :]))  # type: ignore[index]
+        set_tag(SynapseTags.FUNC_KWARGS, str(kwargs))
+        yield
+
+    return _custom_sync_async_decorator(func, _wrapping_logic)
 
 
 @contextlib.contextmanager