summary refs log tree commit diff
path: root/synapse/logging
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/logging')
-rw-r--r--synapse/logging/context.py149
-rw-r--r--synapse/logging/opentracing.py24
-rw-r--r--synapse/logging/scopecontextmanager.py2
3 files changed, 119 insertions, 56 deletions
diff --git a/synapse/logging/context.py b/synapse/logging/context.py

index d8ae3188b7..25e78cc82f 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py
@@ -22,20 +22,33 @@ them. See doc/log_contexts.rst for details on how this works. """ -import inspect import logging import threading import typing import warnings -from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Optional, + Tuple, + Type, + TypeVar, + Union, + overload, +) import attr from typing_extensions import Literal from twisted.internet import defer, threads +from twisted.python.threadpool import ThreadPool if TYPE_CHECKING: from synapse.logging.scopecontextmanager import _LogContextScope + from synapse.types import ISynapseReactor logger = logging.getLogger(__name__) @@ -66,7 +79,7 @@ except Exception: # a hook which can be set during testing to assert that we aren't abusing logcontexts. -def logcontext_error(msg: str): +def logcontext_error(msg: str) -> None: logger.warning(msg) @@ -223,22 +236,19 @@ class _Sentinel: def __str__(self) -> str: return "sentinel" - def copy_to(self, record): - pass - - def start(self, rusage: "Optional[resource.struct_rusage]"): + def start(self, rusage: "Optional[resource.struct_rusage]") -> None: pass - def stop(self, rusage: "Optional[resource.struct_rusage]"): + def stop(self, rusage: "Optional[resource.struct_rusage]") -> None: pass - def add_database_transaction(self, duration_sec): + def add_database_transaction(self, duration_sec: float) -> None: pass - def add_database_scheduled(self, sched_sec): + def add_database_scheduled(self, sched_sec: float) -> None: pass - def record_event_fetch(self, event_count): + def record_event_fetch(self, event_count: int) -> None: pass def __bool__(self) -> Literal[False]: @@ -379,7 +389,12 @@ class LoggingContext: ) return self - def __exit__(self, type, value, traceback) -> None: + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: """Restore the logging context in thread local storage to the state it was before this context was entered. Returns: @@ -399,17 +414,6 @@ class LoggingContext: # recorded against the correct metrics. self.finished = True - def copy_to(self, record) -> None: - """Copy logging fields from this context to a log record or - another LoggingContext - """ - - # we track the current request - record.request = self.request - - # we also track the current scope: - record.scope = self.scope - def start(self, rusage: "Optional[resource.struct_rusage]") -> None: """ Record that this logcontext is currently running. @@ -626,7 +630,12 @@ class PreserveLoggingContext: def __enter__(self) -> None: self._old_context = set_current_context(self._new_context) - def __exit__(self, type, value, traceback) -> None: + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: context = set_current_context(self._old_context) if context != self._new_context: @@ -711,16 +720,61 @@ def nested_logging_context(suffix: str) -> LoggingContext: ) -def preserve_fn(f): +R = TypeVar("R") + + +@overload +def preserve_fn( # type: ignore[misc] + f: Callable[..., Awaitable[R]], +) -> Callable[..., "defer.Deferred[R]"]: + # The `type: ignore[misc]` above suppresses + # "Overloaded function signatures 1 and 2 overlap with incompatible return types" + ... + + +@overload +def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]: + ... + + +def preserve_fn( + f: Union[ + Callable[..., R], + Callable[..., Awaitable[R]], + ] +) -> Callable[..., "defer.Deferred[R]"]: """Function decorator which wraps the function with run_in_background""" - def g(*args, **kwargs): + def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]": return run_in_background(f, *args, **kwargs) return g -def run_in_background(f, *args, **kwargs) -> defer.Deferred: +@overload +def run_in_background( # type: ignore[misc] + f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any +) -> "defer.Deferred[R]": + # The `type: ignore[misc]` above suppresses + # "Overloaded function signatures 1 and 2 overlap with incompatible return types" + ... + + +@overload +def run_in_background( + f: Callable[..., R], *args: Any, **kwargs: Any +) -> "defer.Deferred[R]": + ... + + +def run_in_background( + f: Union[ + Callable[..., R], + Callable[..., Awaitable[R]], + ], + *args: Any, + **kwargs: Any, +) -> "defer.Deferred[R]": """Calls a function, ensuring that the current context is restored after return from the function, and that the sentinel context is set once the deferred returned by the function completes. @@ -751,6 +805,10 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred: # At this point we should have a Deferred, if not then f was a synchronous # function, wrap it in a Deferred for consistency. if not isinstance(res, defer.Deferred): + # `res` is not a `Deferred` and not a `Coroutine`. + # There are no other types of `Awaitable`s we expect to encounter in Synapse. + assert not isinstance(res, Awaitable) + return defer.succeed(res) if res.called and not res.paused: @@ -778,13 +836,14 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred: return res -def make_deferred_yieldable(deferred): - """Given a deferred (or coroutine), make it follow the Synapse logcontext - rules: +T = TypeVar("T") + - If the deferred has completed (or is not actually a Deferred), essentially - does nothing (just returns another completed deferred with the - result/failure). +def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": + """Given a deferred, make it follow the Synapse logcontext rules: + + If the deferred has completed, essentially does nothing (just returns another + completed deferred with the result/failure). If the deferred has not yet completed, resets the logcontext before returning a deferred. Then, when the deferred completes, restores the @@ -792,16 +851,6 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to run_in_background.) """ - if inspect.isawaitable(deferred): - # If we're given a coroutine we convert it to a deferred so that we - # run it and find out if it immediately finishes, it it does then we - # don't need to fiddle with log contexts at all and can return - # immediately. - deferred = defer.ensureDeferred(deferred) - - if not isinstance(deferred, defer.Deferred): - return deferred - if deferred.called and not deferred.paused: # it looks like this deferred is ready to run any callbacks we give it # immediately. We may as well optimise out the logcontext faffery. @@ -823,7 +872,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: return result -def defer_to_thread(reactor, f, *args, **kwargs): +def defer_to_thread( + reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any +) -> "defer.Deferred[R]": """ Calls the function `f` using a thread from the reactor's default threadpool and returns the result as a Deferred. @@ -855,7 +906,13 @@ def defer_to_thread(reactor, f, *args, **kwargs): return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) -def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): +def defer_to_threadpool( + reactor: "ISynapseReactor", + threadpool: ThreadPool, + f: Callable[..., R], + *args: Any, + **kwargs: Any, +) -> "defer.Deferred[R]": """ A wrapper for twisted.internet.threads.deferToThreadpool, which handles logcontexts correctly. @@ -897,7 +954,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): assert isinstance(curr_context, LoggingContext) parent_context = curr_context - def g(): + def g() -> R: with LoggingContext(str(curr_context), parent_context=parent_context): return f(*args, **kwargs) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 20d23a4260..5d93ab07f1 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py
@@ -222,8 +222,8 @@ try: tags = opentracing.tags except ImportError: - opentracing = None - tags = _DummyTagNames + opentracing = None # type: ignore[assignment] + tags = _DummyTagNames # type: ignore[assignment] try: from jaeger_client import Config as JaegerConfig @@ -366,7 +366,7 @@ def init_tracer(hs: "HomeServer"): global opentracing if not hs.config.tracing.opentracer_enabled: # We don't have a tracer - opentracing = None + opentracing = None # type: ignore[assignment] return if not opentracing or not JaegerConfig: @@ -452,7 +452,7 @@ def start_active_span( """ if opentracing is None: - return noop_context_manager() + return noop_context_manager() # type: ignore[unreachable] return opentracing.tracer.start_active_span( operation_name, @@ -477,7 +477,7 @@ def start_active_span_follows_from( forced, the new span will also have tracing forced. """ if opentracing is None: - return noop_context_manager() + return noop_context_manager() # type: ignore[unreachable] references = [opentracing.follows_from(context) for context in contexts] scope = start_active_span(operation_name, references=references) @@ -514,7 +514,7 @@ def start_active_span_from_request( # Also, twisted uses byte arrays while opentracing expects strings. if opentracing is None: - return noop_context_manager() + return noop_context_manager() # type: ignore[unreachable] header_dict = { k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders() @@ -553,7 +553,7 @@ def start_active_span_from_edu( references = references or [] if opentracing is None: - return noop_context_manager() + return noop_context_manager() # type: ignore[unreachable] carrier = json_decoder.decode(edu_content.get("context", "{}")).get( "opentracing", {} @@ -594,18 +594,21 @@ def active_span(): @ensure_active_span("set a tag") def set_tag(key, value): """Sets a tag on the active span""" + assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.set_tag(key, value) @ensure_active_span("log") def log_kv(key_values, timestamp=None): """Log to the active span""" + assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.log_kv(key_values, timestamp) @ensure_active_span("set the traces operation name") def set_operation_name(operation_name): """Sets the operation name of the active span""" + assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.set_operation_name(operation_name) @@ -674,6 +677,7 @@ def inject_header_dict( span = opentracing.tracer.active_span carrier: Dict[str, str] = {} + assert span is not None opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier) for key, value in carrier.items(): @@ -716,6 +720,7 @@ def get_active_span_text_map(destination=None): return {} carrier: Dict[str, str] = {} + assert opentracing.tracer.active_span is not None opentracing.tracer.inject( opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier ) @@ -731,6 +736,7 @@ def active_span_context_as_string(): """ carrier: Dict[str, str] = {} if opentracing: + assert opentracing.tracer.active_span is not None opentracing.tracer.inject( opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier ) @@ -773,7 +779,7 @@ def trace(func=None, opname=None): def decorator(func): if opentracing is None: - return func + return func # type: ignore[unreachable] _opname = opname if opname else func.__name__ @@ -864,7 +870,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False): """ if opentracing is None: - yield + yield # type: ignore[unreachable] return request_tags = { diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index b1e8e08fe9..db8ca2c049 100644 --- a/synapse/logging/scopecontextmanager.py +++ b/synapse/logging/scopecontextmanager.py
@@ -71,7 +71,7 @@ class LogContextScopeManager(ScopeManager): if not ctx: # We don't want this scope to affect. logger.error("Tried to activate scope outside of loggingcontext") - return Scope(None, span) + return Scope(None, span) # type: ignore[arg-type] elif ctx.scope is not None: # We want the logging scope to look exactly the same so we give it # a blank suffix