diff options
author | H. Shay <hillerys@element.io> | 2022-01-05 13:15:50 -0800 |
---|---|---|
committer | H. Shay <hillerys@element.io> | 2022-01-05 13:15:50 -0800 |
commit | 55cfb33644bdbbbc935d24ce0d041eb5799a0514 (patch) | |
tree | 65be7485f19072147068af21def689d66464ac2b | |
parent | Merge branch 'develop' into shay/add_types_opentracing.py (diff) | |
download | synapse-shay/add_types_opentracing.py.tar.xz |
some type hints github/shay/add_types_opentracing.py shay/add_types_opentracing.py
-rw-r--r-- | synapse/logging/opentracing.py | 152 |
1 files changed, 97 insertions, 55 deletions
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 8d59f9e45b..59673e8a57 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -168,10 +168,26 @@ import inspect import logging import re from functools import wraps -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Type, Any, Callable, Iterator, Iterable, \ - Union, Match +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Collection, + Dict, + Iterable, + List, + Match, + Optional, + Pattern, + Type, + TypeVar, + Union, + cast, + overload, +) import attr +from mypy.nodes import JsonDict from twisted.internet import defer from twisted.web.http import Request @@ -181,11 +197,13 @@ from synapse.config import ConfigError from synapse.util import json_decoder, json_encoder if TYPE_CHECKING: + from opentracing.span import Span, SpanContext + from opentracing.tracer import Reference + from synapse.http.site import SynapseRequest from synapse.server import HomeServer - from opentracing.tracer import Reference - from opentracing.span import Span, SpanContext - from opentracing.scope import Scope + +F = TypeVar("F", bound=Callable[..., Any]) # Helper class @@ -260,7 +278,7 @@ try: def set_process(self, *args: Any, **kwargs: Any) -> None: return self._reporter.set_process(*args, **kwargs) - def report_span(self, span: 'Span') -> None: + def report_span(self, span: "Span") -> None: try: return self._reporter.report_span(span) except Exception: @@ -307,12 +325,14 @@ _homeserver_whitelist: Optional[Pattern[str]] = None Sentinel = object() +R = TypeVar("R") -def only_if_tracing(func: Callable) -> Callable: + +def only_if_tracing(func: Callable[..., R]) -> Callable[..., Optional[R]]: """Executes the function only if we're tracing. Otherwise returns None.""" @wraps(func) - def _only_if_tracing_inner(*args: Any, **kwargs: Any) -> Optional[Callable]: + def _only_if_tracing_inner(*args: Any, **kwargs: Any) -> Optional[R]: if opentracing: return func(*args, **kwargs) else: @@ -321,7 +341,23 @@ def only_if_tracing(func: Callable) -> Callable: return _only_if_tracing_inner -def ensure_active_span(message: str, ret: Optional[object]=None) -> Optional[object]: +@overload +def ensure_active_span( + message: str, +) -> Callable[[Callable[..., None]], Callable[..., None]]: + ... + + +@overload +def ensure_active_span( + message: str, ret: R +) -> Callable[[Callable[..., R]], Callable[..., R]]: + ... + + +def ensure_active_span( + message: str, ret: Any = None +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Executes the operation only if opentracing is enabled and there is an active span. If there is no active span it logs message at the error level. @@ -334,9 +370,9 @@ def ensure_active_span(message: str, ret: Optional[object]=None) -> Optional[obj was no active span. """ - def ensure_active_span_inner_1(func: Callable) -> Callable: + def ensure_active_span_inner_1(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) - def ensure_active_span_inner_2(*args: Any, **kwargs: Any) -> Union[Callable, Optional[object]]: + def ensure_active_span_inner_2(*args, **kwargs) -> Any: if not opentracing: return ret @@ -358,7 +394,7 @@ def ensure_active_span(message: str, ret: Optional[object]=None) -> Optional[obj @contextlib.contextmanager -def noop_context_manager(*args: Any, **kwargs: Any) -> Iterator: +def noop_context_manager(*args: Any, **kwargs: Any): """Does exactly what it says on the tin""" # TODO: replace with contextlib.nullcontext once we drop support for Python 3.6 yield @@ -383,7 +419,7 @@ def init_tracer(hs: "HomeServer") -> None: # Pull out the jaeger config if it was given. Otherwise set it to something sensible. # See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py - + assert set_homeserver_whitelist is not None set_homeserver_whitelist(hs.config.tracing.opentracer_whitelist) from jaeger_client.metrics.prometheus import PrometheusMetricsFactory @@ -441,13 +477,13 @@ def whitelisted_homeserver(destination: str) -> Union[bool, Optional[Match[str]] # Could use kwargs but I want these to be explicit def start_active_span( operation_name: str, - child_of: Union['Span', 'SpanContext', None]=None, - references: Optional[List['Reference']]=None, - tags: Dict['Span', Any]=None, - start_time: float=None, - ignore_active_span: bool=False, - finish_on_close: bool=True, -) -> Union['Scope', Callable]: + child_of: Union["Span", "SpanContext", None] = None, + references: Optional[List["Reference"]] = None, + tags: Optional[dict] = None, + start_time: Optional[float] = None, + ignore_active_span: bool = False, + finish_on_close: bool = True, +): """Starts an active opentracing span. Note, the scope doesn't become active until it has been entered, however, the span starts from the time this message is called. @@ -472,8 +508,8 @@ def start_active_span( def start_active_span_follows_from( - operation_name: str, contexts: Collection, inherit_force_tracing: bool=False -) -> Union['Scope', Callable]: + operation_name: str, contexts: Collection, inherit_force_tracing: bool = False +): """Starts an active opentracing span, with additional references to previous spans Args: @@ -491,6 +527,7 @@ def start_active_span_follows_from( if inherit_force_tracing and any( is_context_forced_tracing(ctx) for ctx in contexts ): + assert force_tracing is not None force_tracing(scope.span) return scope @@ -498,13 +535,13 @@ def start_active_span_follows_from( def start_active_span_from_request( request: Request, - operation_name, - references=None, - tags=None, - start_time=None, - ignore_active_span=False, - finish_on_close=True, -) -> 'SpanContext': + operation_name: str, + references: Optional[List["Reference"]] = None, + tags: Optional[dict] = None, + start_time: Optional[float] = None, + ignore_active_span: bool = False, + finish_on_close: bool = True, +): """ Extracts a span context from a Twisted Request. args: @@ -539,13 +576,13 @@ def start_active_span_from_request( def start_active_span_from_edu( - edu_content, - operation_name, - references: Optional[list] = None, - tags=None, - start_time=None, - ignore_active_span=False, - finish_on_close=True, + edu_content: JsonDict, + operation_name: str, + references: Optional[List["Reference"]] = None, + tags: Optional[dict] = None, + start_time: Optional[float] = None, + ignore_active_span: bool = False, + finish_on_close: bool = True, ): """ Extracts a span context from an edu and uses it to start a new active span @@ -565,6 +602,7 @@ def start_active_span_from_edu( "opentracing", {} ) context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) + assert span_context_from_string is not None _references = [ opentracing.child_of(span_context_from_string(x)) for x in carrier.get("references", []) @@ -592,53 +630,53 @@ def start_active_span_from_edu( # Opentracing setters for tags, logs, etc @only_if_tracing -def active_span(): +def active_span() -> Optional[Span]: """Get the currently active span, if any""" return opentracing.tracer.active_span @ensure_active_span("set a tag") -def set_tag(key, value): +def set_tag(key: str, value: Union[str, bool, int, float]) -> None: """Sets a tag on the active span""" assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.set_tag(key, value) @ensure_active_span("log") -def log_kv(key_values, timestamp=None): +def log_kv(key_values: Dict[str, Any], timestamp: Optional[float] = None) -> None: """Log to the active span""" assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.log_kv(key_values, timestamp) @ensure_active_span("set the traces operation name") -def set_operation_name(operation_name): +def set_operation_name(operation_name: str) -> None: """Sets the operation name of the active span""" assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.set_operation_name(operation_name) @only_if_tracing -def force_tracing(span=Sentinel) -> None: +def force_tracing(span: Union[object, "Span", None] = Sentinel) -> None: """Force sampling for the active/given span and its children. Args: span: span to force tracing for. By default, the active span. """ - if span is Sentinel: - span = opentracing.tracer.active_span if span is None: logger.error("No active span in force_tracing") return + if span is Sentinel: + span = opentracing.tracer.active_span - span.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1) + span.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1) # type: ignore[attr-defined] # also set a bit of baggage, so that we have a way of figuring out if # it is enabled later - span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1") + span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1") # type: ignore[attr-defined] -def is_context_forced_tracing(span_context) -> bool: +def is_context_forced_tracing(span_context: Optional["SpanContext"]) -> bool: """Check if sampling has been force for the given span context.""" if span_context is None: return False @@ -677,6 +715,7 @@ def inject_header_dict( raise ValueError( "destination must be given unless check_destination is False" ) + assert whitelisted_homeserver is not None if not whitelisted_homeserver(destination): return @@ -709,8 +748,10 @@ def inject_response_headers(response_headers: Headers) -> None: response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}") -@ensure_active_span("get the active span context as a dict", ret={}) -def get_active_span_text_map(destination=None): +@ensure_active_span( + "get the active span context as a dict", ret=cast(Dict[str, str], {}) +) +def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]: """ Gets a span context as a dict. This can be used instead of manually injecting a span into an empty carrier. @@ -722,6 +763,7 @@ def get_active_span_text_map(destination=None): dict: the active span's context if opentracing is enabled, otherwise empty. """ + assert whitelisted_homeserver is not None if destination and not whitelisted_homeserver(destination): return {} @@ -734,8 +776,8 @@ def get_active_span_text_map(destination=None): return carrier -@ensure_active_span("get the span context as a string.", ret={}) -def active_span_context_as_string(): +@ensure_active_span("get the span context as a string.", ret="{}") +def active_span_context_as_string() -> str: """ Returns: The active span context encoded as a string. @@ -749,7 +791,7 @@ def active_span_context_as_string(): return json_encoder.encode(carrier) -def span_context_from_request(request: Request) -> "Optional[opentracing.SpanContext]": +def span_context_from_request(request: Request) -> "Optional['SpanContext']": """Extract an opentracing context from the headers on an HTTP request This is useful when we have received an HTTP request from another part of our @@ -764,17 +806,17 @@ def span_context_from_request(request: Request) -> "Optional[opentracing.SpanCon @only_if_tracing -def span_context_from_string(carrier): +def span_context_from_string(carrier: str) -> "SpanContext": """ Returns: The active span context decoded from a string. """ - carrier = json_decoder.decode(carrier) - return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) + dict_carrier = json_decoder.decode(carrier) + return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, dict_carrier) @only_if_tracing -def extract_text_map(carrier): +def extract_text_map(carrier: dict) -> "SpanContext": """ Wrapper method for opentracing's tracer.extract for TEXT_MAP. Args: |