summary refs log tree commit diff
path: root/synapse/logging/opentracing.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-07-21 08:01:52 -0400
committerGitHub <noreply@github.com>2022-07-21 12:01:52 +0000
commit50122754c8743df5c904e81b634fdfdeea64e795 (patch)
tree3c6b3e46aec5a66a05eaffb688154eb87b198fcb /synapse/logging/opentracing.py
parentUse cache store remove base slaved (#13329) (diff)
downloadsynapse-50122754c8743df5c904e81b634fdfdeea64e795.tar.xz
Add missing types to opentracing. (#13345)
After this change `synapse.logging` is fully typed.
Diffstat (limited to 'synapse/logging/opentracing.py')
-rw-r--r--synapse/logging/opentracing.py44
1 files changed, 35 insertions, 9 deletions
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 17e729f0c7..ad5cbf46a4 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -182,6 +182,8 @@ from typing import (
     Type,
     TypeVar,
     Union,
+    cast,
+    overload,
 )
 
 import attr
@@ -328,6 +330,7 @@ class _Sentinel(enum.Enum):
 
 P = ParamSpec("P")
 R = TypeVar("R")
+T = TypeVar("T")
 
 
 def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
@@ -343,22 +346,43 @@ def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
     return _only_if_tracing_inner
 
 
-def ensure_active_span(message: str, ret=None):
+@overload
+def ensure_active_span(
+    message: str,
+) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]:
+    ...
+
+
+@overload
+def ensure_active_span(
+    message: str, ret: T
+) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]:
+    ...
+
+
+def ensure_active_span(
+    message: str, ret: Optional[T] = None
+) -> Callable[[Callable[P, R]], Callable[P, Union[Optional[T], R]]]:
     """Executes the operation only if opentracing is enabled and there is an active span.
     If there is no active span it logs message at the error level.
 
     Args:
         message: Message which fills in "There was no active span when trying to %s"
             in the error log if there is no active span and opentracing is enabled.
-        ret (object): return value if opentracing is None or there is no active span.
+        ret: return value if opentracing is None or there is no active span.
 
-    Returns (object): The result of the func or ret if opentracing is disabled or there
+    Returns:
+        The result of the func, falling back to ret if opentracing is disabled or there
         was no active span.
     """
 
-    def ensure_active_span_inner_1(func):
+    def ensure_active_span_inner_1(
+        func: Callable[P, R]
+    ) -> Callable[P, Union[Optional[T], R]]:
         @wraps(func)
-        def ensure_active_span_inner_2(*args, **kwargs):
+        def ensure_active_span_inner_2(
+            *args: P.args, **kwargs: P.kwargs
+        ) -> Union[Optional[T], R]:
             if not opentracing:
                 return ret
 
@@ -464,7 +488,7 @@ def start_active_span(
     finish_on_close: bool = True,
     *,
     tracer: Optional["opentracing.Tracer"] = None,
-):
+) -> "opentracing.Scope":
     """Starts an active opentracing span.
 
     Records the start time for the span, and sets it as the "active span" in the
@@ -502,7 +526,7 @@ def start_active_span_follows_from(
     *,
     inherit_force_tracing: bool = False,
     tracer: Optional["opentracing.Tracer"] = None,
-):
+) -> "opentracing.Scope":
     """Starts an active opentracing span, with additional references to previous spans
 
     Args:
@@ -717,7 +741,9 @@ def inject_response_headers(response_headers: Headers) -> None:
         response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
 
 
-@ensure_active_span("get the active span context as a dict", ret={})
+@ensure_active_span(
+    "get the active span context as a dict", ret=cast(Dict[str, str], {})
+)
 def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
     """
     Gets a span context as a dict. This can be used instead of manually
@@ -886,7 +912,7 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
         for i, arg in enumerate(argspec.args[1:]):
             set_tag("ARG_" + arg, args[i])  # type: ignore[index]
         set_tag("args", args[len(argspec.args) :])  # type: ignore[index]
-        set_tag("kwargs", kwargs)
+        set_tag("kwargs", str(kwargs))
         return func(*args, **kwargs)
 
     return _tag_args_inner