diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 17e729f0c7..ad5cbf46a4 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -182,6 +182,8 @@ from typing import (
Type,
TypeVar,
Union,
+ cast,
+ overload,
)
import attr
@@ -328,6 +330,7 @@ class _Sentinel(enum.Enum):
P = ParamSpec("P")
R = TypeVar("R")
+T = TypeVar("T")
def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
@@ -343,22 +346,43 @@ def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
return _only_if_tracing_inner
-def ensure_active_span(message: str, ret=None):
+@overload
+def ensure_active_span(
+ message: str,
+) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]:
+ ...
+
+
+@overload
+def ensure_active_span(
+ message: str, ret: T
+) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]:
+ ...
+
+
+def ensure_active_span(
+ message: str, ret: Optional[T] = None
+) -> Callable[[Callable[P, R]], Callable[P, Union[Optional[T], R]]]:
"""Executes the operation only if opentracing is enabled and there is an active span.
If there is no active span it logs message at the error level.
Args:
message: Message which fills in "There was no active span when trying to %s"
in the error log if there is no active span and opentracing is enabled.
- ret (object): return value if opentracing is None or there is no active span.
+ ret: return value if opentracing is None or there is no active span.
- Returns (object): The result of the func or ret if opentracing is disabled or there
+ Returns:
+ The result of the func, falling back to ret if opentracing is disabled or there
was no active span.
"""
- def ensure_active_span_inner_1(func):
+ def ensure_active_span_inner_1(
+ func: Callable[P, R]
+ ) -> Callable[P, Union[Optional[T], R]]:
@wraps(func)
- def ensure_active_span_inner_2(*args, **kwargs):
+ def ensure_active_span_inner_2(
+ *args: P.args, **kwargs: P.kwargs
+ ) -> Union[Optional[T], R]:
if not opentracing:
return ret
@@ -464,7 +488,7 @@ def start_active_span(
finish_on_close: bool = True,
*,
tracer: Optional["opentracing.Tracer"] = None,
-):
+) -> "opentracing.Scope":
"""Starts an active opentracing span.
Records the start time for the span, and sets it as the "active span" in the
@@ -502,7 +526,7 @@ def start_active_span_follows_from(
*,
inherit_force_tracing: bool = False,
tracer: Optional["opentracing.Tracer"] = None,
-):
+) -> "opentracing.Scope":
"""Starts an active opentracing span, with additional references to previous spans
Args:
@@ -717,7 +741,9 @@ def inject_response_headers(response_headers: Headers) -> None:
response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
-@ensure_active_span("get the active span context as a dict", ret={})
+@ensure_active_span(
+ "get the active span context as a dict", ret=cast(Dict[str, str], {})
+)
def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
"""
Gets a span context as a dict. This can be used instead of manually
@@ -886,7 +912,7 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
for i, arg in enumerate(argspec.args[1:]):
set_tag("ARG_" + arg, args[i]) # type: ignore[index]
set_tag("args", args[len(argspec.args) :]) # type: ignore[index]
- set_tag("kwargs", kwargs)
+ set_tag("kwargs", str(kwargs))
return func(*args, **kwargs)
return _tag_args_inner
|