diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 8d59f9e45b..59673e8a57 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -168,10 +168,26 @@ import inspect
import logging
import re
from functools import wraps
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Type, Any, Callable, Iterator, Iterable, \
- Union, Match
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Match,
+ Optional,
+ Pattern,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
import attr
+from mypy.nodes import JsonDict
from twisted.internet import defer
from twisted.web.http import Request
@@ -181,11 +197,13 @@ from synapse.config import ConfigError
from synapse.util import json_decoder, json_encoder
if TYPE_CHECKING:
+ from opentracing.span import Span, SpanContext
+ from opentracing.tracer import Reference
+
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
- from opentracing.tracer import Reference
- from opentracing.span import Span, SpanContext
- from opentracing.scope import Scope
+
+F = TypeVar("F", bound=Callable[..., Any])
# Helper class
@@ -260,7 +278,7 @@ try:
def set_process(self, *args: Any, **kwargs: Any) -> None:
return self._reporter.set_process(*args, **kwargs)
- def report_span(self, span: 'Span') -> None:
+ def report_span(self, span: "Span") -> None:
try:
return self._reporter.report_span(span)
except Exception:
@@ -307,12 +325,14 @@ _homeserver_whitelist: Optional[Pattern[str]] = None
Sentinel = object()
+R = TypeVar("R")
-def only_if_tracing(func: Callable) -> Callable:
+
+def only_if_tracing(func: Callable[..., R]) -> Callable[..., Optional[R]]:
"""Executes the function only if we're tracing. Otherwise returns None."""
@wraps(func)
- def _only_if_tracing_inner(*args: Any, **kwargs: Any) -> Optional[Callable]:
+ def _only_if_tracing_inner(*args: Any, **kwargs: Any) -> Optional[R]:
if opentracing:
return func(*args, **kwargs)
else:
@@ -321,7 +341,23 @@ def only_if_tracing(func: Callable) -> Callable:
return _only_if_tracing_inner
-def ensure_active_span(message: str, ret: Optional[object]=None) -> Optional[object]:
+@overload
+def ensure_active_span(
+ message: str,
+) -> Callable[[Callable[..., None]], Callable[..., None]]:
+ ...
+
+
+@overload
+def ensure_active_span(
+ message: str, ret: R
+) -> Callable[[Callable[..., R]], Callable[..., R]]:
+ ...
+
+
+def ensure_active_span(
+ message: str, ret: Any = None
+) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Executes the operation only if opentracing is enabled and there is an active span.
If there is no active span it logs message at the error level.
@@ -334,9 +370,9 @@ def ensure_active_span(message: str, ret: Optional[object]=None) -> Optional[obj
was no active span.
"""
- def ensure_active_span_inner_1(func: Callable) -> Callable:
+ def ensure_active_span_inner_1(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
- def ensure_active_span_inner_2(*args: Any, **kwargs: Any) -> Union[Callable, Optional[object]]:
+ def ensure_active_span_inner_2(*args, **kwargs) -> Any:
if not opentracing:
return ret
@@ -358,7 +394,7 @@ def ensure_active_span(message: str, ret: Optional[object]=None) -> Optional[obj
@contextlib.contextmanager
-def noop_context_manager(*args: Any, **kwargs: Any) -> Iterator:
+def noop_context_manager(*args: Any, **kwargs: Any):
"""Does exactly what it says on the tin"""
# TODO: replace with contextlib.nullcontext once we drop support for Python 3.6
yield
@@ -383,7 +419,7 @@ def init_tracer(hs: "HomeServer") -> None:
# Pull out the jaeger config if it was given. Otherwise set it to something sensible.
# See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
-
+ assert set_homeserver_whitelist is not None
set_homeserver_whitelist(hs.config.tracing.opentracer_whitelist)
from jaeger_client.metrics.prometheus import PrometheusMetricsFactory
@@ -441,13 +477,13 @@ def whitelisted_homeserver(destination: str) -> Union[bool, Optional[Match[str]]
# Could use kwargs but I want these to be explicit
def start_active_span(
operation_name: str,
- child_of: Union['Span', 'SpanContext', None]=None,
- references: Optional[List['Reference']]=None,
- tags: Dict['Span', Any]=None,
- start_time: float=None,
- ignore_active_span: bool=False,
- finish_on_close: bool=True,
-) -> Union['Scope', Callable]:
+ child_of: Union["Span", "SpanContext", None] = None,
+ references: Optional[List["Reference"]] = None,
+ tags: Optional[dict] = None,
+ start_time: Optional[float] = None,
+ ignore_active_span: bool = False,
+ finish_on_close: bool = True,
+):
"""Starts an active opentracing span. Note, the scope doesn't become active
until it has been entered, however, the span starts from the time this
message is called.
@@ -472,8 +508,8 @@ def start_active_span(
def start_active_span_follows_from(
- operation_name: str, contexts: Collection, inherit_force_tracing: bool=False
-) -> Union['Scope', Callable]:
+ operation_name: str, contexts: Collection, inherit_force_tracing: bool = False
+):
"""Starts an active opentracing span, with additional references to previous spans
Args:
@@ -491,6 +527,7 @@ def start_active_span_follows_from(
if inherit_force_tracing and any(
is_context_forced_tracing(ctx) for ctx in contexts
):
+ assert force_tracing is not None
force_tracing(scope.span)
return scope
@@ -498,13 +535,13 @@ def start_active_span_follows_from(
def start_active_span_from_request(
request: Request,
- operation_name,
- references=None,
- tags=None,
- start_time=None,
- ignore_active_span=False,
- finish_on_close=True,
-) -> 'SpanContext':
+ operation_name: str,
+ references: Optional[List["Reference"]] = None,
+ tags: Optional[dict] = None,
+ start_time: Optional[float] = None,
+ ignore_active_span: bool = False,
+ finish_on_close: bool = True,
+):
"""
Extracts a span context from a Twisted Request.
args:
@@ -539,13 +576,13 @@ def start_active_span_from_request(
def start_active_span_from_edu(
- edu_content,
- operation_name,
- references: Optional[list] = None,
- tags=None,
- start_time=None,
- ignore_active_span=False,
- finish_on_close=True,
+ edu_content: JsonDict,
+ operation_name: str,
+ references: Optional[List["Reference"]] = None,
+ tags: Optional[dict] = None,
+ start_time: Optional[float] = None,
+ ignore_active_span: bool = False,
+ finish_on_close: bool = True,
):
"""
Extracts a span context from an edu and uses it to start a new active span
@@ -565,6 +602,7 @@ def start_active_span_from_edu(
"opentracing", {}
)
context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
+ assert span_context_from_string is not None
_references = [
opentracing.child_of(span_context_from_string(x))
for x in carrier.get("references", [])
@@ -592,53 +630,53 @@ def start_active_span_from_edu(
# Opentracing setters for tags, logs, etc
@only_if_tracing
-def active_span():
+def active_span() -> Optional[Span]:
"""Get the currently active span, if any"""
return opentracing.tracer.active_span
@ensure_active_span("set a tag")
-def set_tag(key, value):
+def set_tag(key: str, value: Union[str, bool, int, float]) -> None:
"""Sets a tag on the active span"""
assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_tag(key, value)
@ensure_active_span("log")
-def log_kv(key_values, timestamp=None):
+def log_kv(key_values: Dict[str, Any], timestamp: Optional[float] = None) -> None:
"""Log to the active span"""
assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.log_kv(key_values, timestamp)
@ensure_active_span("set the traces operation name")
-def set_operation_name(operation_name):
+def set_operation_name(operation_name: str) -> None:
"""Sets the operation name of the active span"""
assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_operation_name(operation_name)
@only_if_tracing
-def force_tracing(span=Sentinel) -> None:
+def force_tracing(span: Union[object, "Span", None] = Sentinel) -> None:
"""Force sampling for the active/given span and its children.
Args:
span: span to force tracing for. By default, the active span.
"""
- if span is Sentinel:
- span = opentracing.tracer.active_span
if span is None:
logger.error("No active span in force_tracing")
return
+ if span is Sentinel:
+ span = opentracing.tracer.active_span
- span.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
+ span.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1) # type: ignore[attr-defined]
# also set a bit of baggage, so that we have a way of figuring out if
# it is enabled later
- span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")
+ span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1") # type: ignore[attr-defined]
-def is_context_forced_tracing(span_context) -> bool:
+def is_context_forced_tracing(span_context: Optional["SpanContext"]) -> bool:
"""Check if sampling has been force for the given span context."""
if span_context is None:
return False
@@ -677,6 +715,7 @@ def inject_header_dict(
raise ValueError(
"destination must be given unless check_destination is False"
)
+ assert whitelisted_homeserver is not None
if not whitelisted_homeserver(destination):
return
@@ -709,8 +748,10 @@ def inject_response_headers(response_headers: Headers) -> None:
response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
-@ensure_active_span("get the active span context as a dict", ret={})
-def get_active_span_text_map(destination=None):
+@ensure_active_span(
+ "get the active span context as a dict", ret=cast(Dict[str, str], {})
+)
+def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
"""
Gets a span context as a dict. This can be used instead of manually
injecting a span into an empty carrier.
@@ -722,6 +763,7 @@ def get_active_span_text_map(destination=None):
dict: the active span's context if opentracing is enabled, otherwise empty.
"""
+ assert whitelisted_homeserver is not None
if destination and not whitelisted_homeserver(destination):
return {}
@@ -734,8 +776,8 @@ def get_active_span_text_map(destination=None):
return carrier
-@ensure_active_span("get the span context as a string.", ret={})
-def active_span_context_as_string():
+@ensure_active_span("get the span context as a string.", ret="{}")
+def active_span_context_as_string() -> str:
"""
Returns:
The active span context encoded as a string.
@@ -749,7 +791,7 @@ def active_span_context_as_string():
return json_encoder.encode(carrier)
-def span_context_from_request(request: Request) -> "Optional[opentracing.SpanContext]":
+def span_context_from_request(request: Request) -> "Optional['SpanContext']":
"""Extract an opentracing context from the headers on an HTTP request
This is useful when we have received an HTTP request from another part of our
@@ -764,17 +806,17 @@ def span_context_from_request(request: Request) -> "Optional[opentracing.SpanCon
@only_if_tracing
-def span_context_from_string(carrier):
+def span_context_from_string(carrier: str) -> "SpanContext":
"""
Returns:
The active span context decoded from a string.
"""
- carrier = json_decoder.decode(carrier)
- return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
+ dict_carrier = json_decoder.decode(carrier)
+ return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, dict_carrier)
@only_if_tracing
-def extract_text_map(carrier):
+def extract_text_map(carrier: dict) -> "SpanContext":
"""
Wrapper method for opentracing's tracer.extract for TEXT_MAP.
Args:
|