summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-05-27 11:17:33 +0100
committerGitHub <noreply@github.com>2022-05-27 11:17:33 +0100
commit3503f42741d438b67490bc774ee2c3a856f6bc81 (patch)
tree2b39ed642df3ac4ae149bdde6c13d00edc5683ff
parentFix room deletion (#12889) (diff)
downloadsynapse-3503f42741d438b67490bc774ee2c3a856f6bc81.tar.xz
Easy type hints in synapse.logging.opentracing (#12894)
-rw-r--r--changelog.d/12894.misc1
-rw-r--r--synapse/config/tracer.py6
-rw-r--r--synapse/logging/opentracing.py114
-rw-r--r--synapse/metrics/background_process_metrics.py9
4 files changed, 73 insertions, 57 deletions
diff --git a/changelog.d/12894.misc b/changelog.d/12894.misc
new file mode 100644
index 0000000000..646a62fccb
--- /dev/null
+++ b/changelog.d/12894.misc
@@ -0,0 +1 @@
+Add type annotations to `synapse.logging.opentracing`.
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index 3472a9a01b..ae68a3dd1a 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Set
+from typing import Any, List, Set
 
 from synapse.types import JsonDict
 from synapse.util.check_dependencies import DependencyException, check_requirements
@@ -49,7 +49,9 @@ class TracerConfig(Config):
 
         # The tracer is enabled so sanitize the config
 
-        self.opentracer_whitelist = opentracing_config.get("homeserver_whitelist", [])
+        self.opentracer_whitelist: List[str] = opentracing_config.get(
+            "homeserver_whitelist", []
+        )
         if not isinstance(self.opentracer_whitelist, list):
             raise ConfigError("Tracer homeserver_whitelist config is malformed")
 
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index a02b5bf6bd..903ec40c86 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -168,9 +168,24 @@ import inspect
 import logging
 import re
 from functools import wraps
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Type
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Collection,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Pattern,
+    Type,
+    TypeVar,
+    Union,
+)
 
 import attr
+from typing_extensions import ParamSpec
 
 from twisted.internet import defer
 from twisted.web.http import Request
@@ -256,7 +271,7 @@ try:
         def set_process(self, *args, **kwargs):
             return self._reporter.set_process(*args, **kwargs)
 
-        def report_span(self, span):
+        def report_span(self, span: "opentracing.Span") -> None:
             try:
                 return self._reporter.report_span(span)
             except Exception:
@@ -307,15 +322,19 @@ _homeserver_whitelist: Optional[Pattern[str]] = None
 Sentinel = object()
 
 
-def only_if_tracing(func):
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
     """Executes the function only if we're tracing. Otherwise returns None."""
 
     @wraps(func)
-    def _only_if_tracing_inner(*args, **kwargs):
+    def _only_if_tracing_inner(*args: P.args, **kwargs: P.kwargs) -> Optional[R]:
         if opentracing:
             return func(*args, **kwargs)
         else:
-            return
+            return None
 
     return _only_if_tracing_inner
 
@@ -356,17 +375,10 @@ def ensure_active_span(message, ret=None):
     return ensure_active_span_inner_1
 
 
-@contextlib.contextmanager
-def noop_context_manager(*args, **kwargs):
-    """Does exactly what it says on the tin"""
-    # TODO: replace with contextlib.nullcontext once we drop support for Python 3.6
-    yield
-
-
 # Setup
 
 
-def init_tracer(hs: "HomeServer"):
+def init_tracer(hs: "HomeServer") -> None:
     """Set the whitelists and initialise the JaegerClient tracer"""
     global opentracing
     if not hs.config.tracing.opentracer_enabled:
@@ -408,11 +420,11 @@ def init_tracer(hs: "HomeServer"):
 
 
 @only_if_tracing
-def set_homeserver_whitelist(homeserver_whitelist):
+def set_homeserver_whitelist(homeserver_whitelist: Iterable[str]) -> None:
     """Sets the homeserver whitelist
 
     Args:
-        homeserver_whitelist (Iterable[str]): regex of whitelisted homeservers
+        homeserver_whitelist: regexes specifying whitelisted homeservers
     """
     global _homeserver_whitelist
     if homeserver_whitelist:
@@ -423,15 +435,15 @@ def set_homeserver_whitelist(homeserver_whitelist):
 
 
 @only_if_tracing
-def whitelisted_homeserver(destination):
+def whitelisted_homeserver(destination: str) -> bool:
     """Checks if a destination matches the whitelist
 
     Args:
-        destination (str)
+        destination
     """
 
     if _homeserver_whitelist:
-        return _homeserver_whitelist.match(destination)
+        return _homeserver_whitelist.match(destination) is not None
     return False
 
 
@@ -457,11 +469,11 @@ def start_active_span(
     Args:
         See opentracing.tracer
     Returns:
-        scope (Scope) or noop_context_manager
+        scope (Scope) or contextlib.nullcontext
     """
 
     if opentracing is None:
-        return noop_context_manager()  # type: ignore[unreachable]
+        return contextlib.nullcontext()  # type: ignore[unreachable]
 
     if tracer is None:
         # use the global tracer by default
@@ -505,7 +517,7 @@ def start_active_span_follows_from(
         tracer: override the opentracing tracer. By default the global tracer is used.
     """
     if opentracing is None:
-        return noop_context_manager()  # type: ignore[unreachable]
+        return contextlib.nullcontext()  # type: ignore[unreachable]
 
     references = [opentracing.follows_from(context) for context in contexts]
     scope = start_active_span(
@@ -525,19 +537,19 @@ def start_active_span_follows_from(
 
 
 def start_active_span_from_edu(
-    edu_content,
-    operation_name,
-    references: Optional[list] = None,
-    tags=None,
-    start_time=None,
-    ignore_active_span=False,
-    finish_on_close=True,
-):
+    edu_content: Dict[str, Any],
+    operation_name: str,
+    references: Optional[List["opentracing.Reference"]] = None,
+    tags: Optional[Dict] = None,
+    start_time: Optional[float] = None,
+    ignore_active_span: bool = False,
+    finish_on_close: bool = True,
+) -> "opentracing.Scope":
     """
     Extracts a span context from an edu and uses it to start a new active span
 
     Args:
-        edu_content (dict): and edu_content with a `context` field whose value is
+        edu_content: an edu_content with a `context` field whose value is
         canonical json for a dict which contains opentracing information.
 
         For the other args see opentracing.tracer
@@ -545,7 +557,7 @@ def start_active_span_from_edu(
     references = references or []
 
     if opentracing is None:
-        return noop_context_manager()  # type: ignore[unreachable]
+        return contextlib.nullcontext()  # type: ignore[unreachable]
 
     carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
         "opentracing", {}
@@ -578,27 +590,27 @@ def start_active_span_from_edu(
 
 # Opentracing setters for tags, logs, etc
 @only_if_tracing
-def active_span():
+def active_span() -> Optional["opentracing.Span"]:
     """Get the currently active span, if any"""
     return opentracing.tracer.active_span
 
 
 @ensure_active_span("set a tag")
-def set_tag(key, value):
+def set_tag(key: str, value: Union[str, bool, int, float]) -> None:
     """Sets a tag on the active span"""
     assert opentracing.tracer.active_span is not None
     opentracing.tracer.active_span.set_tag(key, value)
 
 
 @ensure_active_span("log")
-def log_kv(key_values, timestamp=None):
+def log_kv(key_values: Dict[str, Any], timestamp: Optional[float] = None) -> None:
     """Log to the active span"""
     assert opentracing.tracer.active_span is not None
     opentracing.tracer.active_span.log_kv(key_values, timestamp)
 
 
 @ensure_active_span("set the traces operation name")
-def set_operation_name(operation_name):
+def set_operation_name(operation_name: str) -> None:
     """Sets the operation name of the active span"""
     assert opentracing.tracer.active_span is not None
     opentracing.tracer.active_span.set_operation_name(operation_name)
@@ -624,7 +636,9 @@ def force_tracing(span=Sentinel) -> None:
     span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")
 
 
-def is_context_forced_tracing(span_context) -> bool:
+def is_context_forced_tracing(
+    span_context: Optional["opentracing.SpanContext"],
+) -> bool:
     """Check if sampling has been force for the given span context."""
     if span_context is None:
         return False
@@ -696,13 +710,13 @@ def inject_response_headers(response_headers: Headers) -> None:
 
 
 @ensure_active_span("get the active span context as a dict", ret={})
-def get_active_span_text_map(destination=None):
+def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
     """
     Gets a span context as a dict. This can be used instead of manually
     injecting a span into an empty carrier.
 
     Args:
-        destination (str): the name of the remote server.
+        destination: the name of the remote server.
 
     Returns:
         dict: the active span's context if opentracing is enabled, otherwise empty.
@@ -721,7 +735,7 @@ def get_active_span_text_map(destination=None):
 
 
 @ensure_active_span("get the span context as a string.", ret={})
-def active_span_context_as_string():
+def active_span_context_as_string() -> str:
     """
     Returns:
         The active span context encoded as a string.
@@ -750,21 +764,21 @@ def span_context_from_request(request: Request) -> "Optional[opentracing.SpanCon
 
 
 @only_if_tracing
-def span_context_from_string(carrier):
+def span_context_from_string(carrier: str) -> Optional["opentracing.SpanContext"]:
     """
     Returns:
         The active span context decoded from a string.
     """
-    carrier = json_decoder.decode(carrier)
-    return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
+    payload: Dict[str, str] = json_decoder.decode(carrier)
+    return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, payload)
 
 
 @only_if_tracing
-def extract_text_map(carrier):
+def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanContext"]:
     """
     Wrapper method for opentracing's tracer.extract for TEXT_MAP.
     Args:
-        carrier (dict): a dict possibly containing a span context.
+        carrier: a dict possibly containing a span context.
 
     Returns:
         The active span context extracted from carrier.
@@ -843,7 +857,7 @@ def trace(func=None, opname=None):
         return decorator
 
 
-def tag_args(func):
+def tag_args(func: Callable[P, R]) -> Callable[P, R]:
     """
     Tags all of the args to the active span.
     """
@@ -852,11 +866,11 @@ def tag_args(func):
         return func
 
     @wraps(func)
-    def _tag_args_inner(*args, **kwargs):
+    def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
         argspec = inspect.getfullargspec(func)
         for i, arg in enumerate(argspec.args[1:]):
-            set_tag("ARG_" + arg, args[i])
-        set_tag("args", args[len(argspec.args) :])
+            set_tag("ARG_" + arg, args[i])  # type: ignore[index]
+        set_tag("args", args[len(argspec.args) :])  # type: ignore[index]
         set_tag("kwargs", kwargs)
         return func(*args, **kwargs)
 
@@ -864,7 +878,9 @@ def tag_args(func):
 
 
 @contextlib.contextmanager
-def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
+def trace_servlet(
+    request: "SynapseRequest", extract_context: bool = False
+) -> Generator[None, None, None]:
     """Returns a context manager which traces a request. It starts a span
     with some servlet specific tags such as the request metrics name and
     request information.
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 298809742a..eef3462e10 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -14,6 +14,7 @@
 
 import logging
 import threading
+from contextlib import nullcontext
 from functools import wraps
 from types import TracebackType
 from typing import (
@@ -41,11 +42,7 @@ from synapse.logging.context import (
     LoggingContext,
     PreserveLoggingContext,
 )
-from synapse.logging.opentracing import (
-    SynapseTags,
-    noop_context_manager,
-    start_active_span,
-)
+from synapse.logging.opentracing import SynapseTags, start_active_span
 from synapse.metrics._types import Collector
 
 if TYPE_CHECKING:
@@ -238,7 +235,7 @@ def run_as_background_process(
                         f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)}
                     )
                 else:
-                    ctx = noop_context_manager()
+                    ctx = nullcontext()
                 with ctx:
                     return await func(*args, **kwargs)
             except Exception: