summary refs log tree commit diff
diff options
context:
space:
mode:
authorH. Shay <hillerys@element.io>2022-01-05 13:15:50 -0800
committerH. Shay <hillerys@element.io>2022-01-05 13:15:50 -0800
commit55cfb33644bdbbbc935d24ce0d041eb5799a0514 (patch)
tree65be7485f19072147068af21def689d66464ac2b
parentMerge branch 'develop' into shay/add_types_opentracing.py (diff)
downloadsynapse-github/shay/add_types_opentracing.py.tar.xz
-rw-r--r--synapse/logging/opentracing.py152
1 files changed, 97 insertions, 55 deletions
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 8d59f9e45b..59673e8a57 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -168,10 +168,26 @@ import inspect
 import logging
 import re
 from functools import wraps
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Type, Any, Callable, Iterator, Iterable, \
-    Union, Match
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Match,
+    Optional,
+    Pattern,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
 
 import attr
+from mypy.nodes import JsonDict
 
 from twisted.internet import defer
 from twisted.web.http import Request
@@ -181,11 +197,13 @@ from synapse.config import ConfigError
 from synapse.util import json_decoder, json_encoder
 
 if TYPE_CHECKING:
+    from opentracing.span import Span, SpanContext
+    from opentracing.tracer import Reference
+
     from synapse.http.site import SynapseRequest
     from synapse.server import HomeServer
-    from opentracing.tracer import Reference
-    from opentracing.span import Span, SpanContext
-    from opentracing.scope import Scope
+
+F = TypeVar("F", bound=Callable[..., Any])
 
 # Helper class
 
@@ -260,7 +278,7 @@ try:
         def set_process(self, *args: Any, **kwargs: Any) -> None:
             return self._reporter.set_process(*args, **kwargs)
 
-        def report_span(self, span: 'Span') -> None:
+        def report_span(self, span: "Span") -> None:
             try:
                 return self._reporter.report_span(span)
             except Exception:
@@ -307,12 +325,14 @@ _homeserver_whitelist: Optional[Pattern[str]] = None
 
 Sentinel = object()
 
+R = TypeVar("R")
 
-def only_if_tracing(func: Callable) -> Callable:
+
+def only_if_tracing(func: Callable[..., R]) -> Callable[..., Optional[R]]:
     """Executes the function only if we're tracing. Otherwise returns None."""
 
     @wraps(func)
-    def _only_if_tracing_inner(*args: Any, **kwargs: Any) -> Optional[Callable]:
+    def _only_if_tracing_inner(*args: Any, **kwargs: Any) -> Optional[R]:
         if opentracing:
             return func(*args, **kwargs)
         else:
@@ -321,7 +341,23 @@ def only_if_tracing(func: Callable) -> Callable:
     return _only_if_tracing_inner
 
 
-def ensure_active_span(message: str, ret: Optional[object]=None) -> Optional[object]:
+@overload
+def ensure_active_span(
+    message: str,
+) -> Callable[[Callable[..., None]], Callable[..., None]]:
+    ...
+
+
+@overload
+def ensure_active_span(
+    message: str, ret: R
+) -> Callable[[Callable[..., R]], Callable[..., R]]:
+    ...
+
+
+def ensure_active_span(
+    message: str, ret: Any = None
+) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
     """Executes the operation only if opentracing is enabled and there is an active span.
     If there is no active span it logs message at the error level.
 
@@ -334,9 +370,9 @@ def ensure_active_span(message: str, ret: Optional[object]=None) -> Optional[obj
         was no active span.
     """
 
-    def ensure_active_span_inner_1(func: Callable) -> Callable:
+    def ensure_active_span_inner_1(func: Callable[..., Any]) -> Callable[..., Any]:
         @wraps(func)
-        def ensure_active_span_inner_2(*args: Any, **kwargs: Any) -> Union[Callable, Optional[object]]:
+        def ensure_active_span_inner_2(*args, **kwargs) -> Any:
             if not opentracing:
                 return ret
 
@@ -358,7 +394,7 @@ def ensure_active_span(message: str, ret: Optional[object]=None) -> Optional[obj
 
 
 @contextlib.contextmanager
-def noop_context_manager(*args: Any, **kwargs: Any) -> Iterator:
+def noop_context_manager(*args: Any, **kwargs: Any):
     """Does exactly what it says on the tin"""
     # TODO: replace with contextlib.nullcontext once we drop support for Python 3.6
     yield
@@ -383,7 +419,7 @@ def init_tracer(hs: "HomeServer") -> None:
 
     # Pull out the jaeger config if it was given. Otherwise set it to something sensible.
     # See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
-
+    assert set_homeserver_whitelist is not None
     set_homeserver_whitelist(hs.config.tracing.opentracer_whitelist)
 
     from jaeger_client.metrics.prometheus import PrometheusMetricsFactory
@@ -441,13 +477,13 @@ def whitelisted_homeserver(destination: str) -> Union[bool, Optional[Match[str]]
 # Could use kwargs but I want these to be explicit
 def start_active_span(
     operation_name: str,
-    child_of: Union['Span', 'SpanContext', None]=None,
-    references: Optional[List['Reference']]=None,
-    tags: Dict['Span', Any]=None,
-    start_time: float=None,
-    ignore_active_span: bool=False,
-    finish_on_close: bool=True,
-) -> Union['Scope', Callable]:
+    child_of: Union["Span", "SpanContext", None] = None,
+    references: Optional[List["Reference"]] = None,
+    tags: Optional[dict] = None,
+    start_time: Optional[float] = None,
+    ignore_active_span: bool = False,
+    finish_on_close: bool = True,
+):
     """Starts an active opentracing span. Note, the scope doesn't become active
     until it has been entered, however, the span starts from the time this
     message is called.
@@ -472,8 +508,8 @@ def start_active_span(
 
 
 def start_active_span_follows_from(
-    operation_name: str, contexts: Collection, inherit_force_tracing: bool=False
-) -> Union['Scope', Callable]:
+    operation_name: str, contexts: Collection, inherit_force_tracing: bool = False
+):
     """Starts an active opentracing span, with additional references to previous spans
 
     Args:
@@ -491,6 +527,7 @@ def start_active_span_follows_from(
     if inherit_force_tracing and any(
         is_context_forced_tracing(ctx) for ctx in contexts
     ):
+        assert force_tracing is not None
         force_tracing(scope.span)
 
     return scope
@@ -498,13 +535,13 @@ def start_active_span_follows_from(
 
 def start_active_span_from_request(
     request: Request,
-    operation_name,
-    references=None,
-    tags=None,
-    start_time=None,
-    ignore_active_span=False,
-    finish_on_close=True,
-) -> 'SpanContext':
+    operation_name: str,
+    references: Optional[List["Reference"]] = None,
+    tags: Optional[dict] = None,
+    start_time: Optional[float] = None,
+    ignore_active_span: bool = False,
+    finish_on_close: bool = True,
+):
     """
     Extracts a span context from a Twisted Request.
     args:
@@ -539,13 +576,13 @@ def start_active_span_from_request(
 
 
 def start_active_span_from_edu(
-    edu_content,
-    operation_name,
-    references: Optional[list] = None,
-    tags=None,
-    start_time=None,
-    ignore_active_span=False,
-    finish_on_close=True,
+    edu_content: JsonDict,
+    operation_name: str,
+    references: Optional[List["Reference"]] = None,
+    tags: Optional[dict] = None,
+    start_time: Optional[float] = None,
+    ignore_active_span: bool = False,
+    finish_on_close: bool = True,
 ):
     """
     Extracts a span context from an edu and uses it to start a new active span
@@ -565,6 +602,7 @@ def start_active_span_from_edu(
         "opentracing", {}
     )
     context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
+    assert span_context_from_string is not None
     _references = [
         opentracing.child_of(span_context_from_string(x))
         for x in carrier.get("references", [])
@@ -592,53 +630,53 @@ def start_active_span_from_edu(
 
 # Opentracing setters for tags, logs, etc
 @only_if_tracing
-def active_span():
+def active_span() -> Optional[Span]:
     """Get the currently active span, if any"""
     return opentracing.tracer.active_span
 
 
 @ensure_active_span("set a tag")
-def set_tag(key, value):
+def set_tag(key: str, value: Union[str, bool, int, float]) -> None:
     """Sets a tag on the active span"""
     assert opentracing.tracer.active_span is not None
     opentracing.tracer.active_span.set_tag(key, value)
 
 
 @ensure_active_span("log")
-def log_kv(key_values, timestamp=None):
+def log_kv(key_values: Dict[str, Any], timestamp: Optional[float] = None) -> None:
     """Log to the active span"""
     assert opentracing.tracer.active_span is not None
     opentracing.tracer.active_span.log_kv(key_values, timestamp)
 
 
 @ensure_active_span("set the traces operation name")
-def set_operation_name(operation_name):
+def set_operation_name(operation_name: str) -> None:
     """Sets the operation name of the active span"""
     assert opentracing.tracer.active_span is not None
     opentracing.tracer.active_span.set_operation_name(operation_name)
 
 
 @only_if_tracing
-def force_tracing(span=Sentinel) -> None:
+def force_tracing(span: Union[object, "Span", None] = Sentinel) -> None:
     """Force sampling for the active/given span and its children.
 
     Args:
         span: span to force tracing for. By default, the active span.
     """
-    if span is Sentinel:
-        span = opentracing.tracer.active_span
     if span is None:
         logger.error("No active span in force_tracing")
         return
+    if span is Sentinel:
+        span = opentracing.tracer.active_span
 
-    span.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
+    span.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)  # type: ignore[attr-defined]
 
     # also set a bit of baggage, so that we have a way of figuring out if
     # it is enabled later
-    span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")
+    span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")  # type: ignore[attr-defined]
 
 
-def is_context_forced_tracing(span_context) -> bool:
+def is_context_forced_tracing(span_context: Optional["SpanContext"]) -> bool:
     """Check if sampling has been force for the given span context."""
     if span_context is None:
         return False
@@ -677,6 +715,7 @@ def inject_header_dict(
             raise ValueError(
                 "destination must be given unless check_destination is False"
             )
+        assert whitelisted_homeserver is not None
         if not whitelisted_homeserver(destination):
             return
 
@@ -709,8 +748,10 @@ def inject_response_headers(response_headers: Headers) -> None:
         response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
 
 
-@ensure_active_span("get the active span context as a dict", ret={})
-def get_active_span_text_map(destination=None):
+@ensure_active_span(
+    "get the active span context as a dict", ret=cast(Dict[str, str], {})
+)
+def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
     """
     Gets a span context as a dict. This can be used instead of manually
     injecting a span into an empty carrier.
@@ -722,6 +763,7 @@ def get_active_span_text_map(destination=None):
         dict: the active span's context if opentracing is enabled, otherwise empty.
     """
 
+    assert whitelisted_homeserver is not None
     if destination and not whitelisted_homeserver(destination):
         return {}
 
@@ -734,8 +776,8 @@ def get_active_span_text_map(destination=None):
     return carrier
 
 
-@ensure_active_span("get the span context as a string.", ret={})
-def active_span_context_as_string():
+@ensure_active_span("get the span context as a string.", ret="{}")
+def active_span_context_as_string() -> str:
     """
     Returns:
         The active span context encoded as a string.
@@ -749,7 +791,7 @@ def active_span_context_as_string():
     return json_encoder.encode(carrier)
 
 
-def span_context_from_request(request: Request) -> "Optional[opentracing.SpanContext]":
+def span_context_from_request(request: Request) -> "Optional['SpanContext']":
     """Extract an opentracing context from the headers on an HTTP request
 
     This is useful when we have received an HTTP request from another part of our
@@ -764,17 +806,17 @@ def span_context_from_request(request: Request) -> "Optional[opentracing.SpanCon
 
 
 @only_if_tracing
-def span_context_from_string(carrier):
+def span_context_from_string(carrier: str) -> "SpanContext":
     """
     Returns:
         The active span context decoded from a string.
     """
-    carrier = json_decoder.decode(carrier)
-    return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
+    dict_carrier = json_decoder.decode(carrier)
+    return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, dict_carrier)
 
 
 @only_if_tracing
-def extract_text_map(carrier):
+def extract_text_map(carrier: dict) -> "SpanContext":
     """
     Wrapper method for opentracing's tracer.extract for TEXT_MAP.
     Args: