diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index d8ae3188b7..d4ee893376 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -22,20 +22,33 @@ them.
See doc/log_contexts.rst for details on how this works.
"""
-import inspect
import logging
import threading
import typing
import warnings
-from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
+from types import TracebackType
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ overload,
+)
import attr
from typing_extensions import Literal
from twisted.internet import defer, threads
+from twisted.python.threadpool import ThreadPool
if TYPE_CHECKING:
from synapse.logging.scopecontextmanager import _LogContextScope
+ from synapse.types import ISynapseReactor
logger = logging.getLogger(__name__)
@@ -55,7 +68,6 @@ try:
def get_thread_resource_usage() -> "Optional[resource.struct_rusage]":
return resource.getrusage(RUSAGE_THREAD)
-
except Exception:
# If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
# won't track resource usage.
@@ -66,7 +78,7 @@ except Exception:
# a hook which can be set during testing to assert that we aren't abusing logcontexts.
-def logcontext_error(msg: str):
+def logcontext_error(msg: str) -> None:
logger.warning(msg)
@@ -223,22 +235,19 @@ class _Sentinel:
def __str__(self) -> str:
return "sentinel"
- def copy_to(self, record):
- pass
-
- def start(self, rusage: "Optional[resource.struct_rusage]"):
+ def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
pass
- def stop(self, rusage: "Optional[resource.struct_rusage]"):
+ def stop(self, rusage: "Optional[resource.struct_rusage]") -> None:
pass
- def add_database_transaction(self, duration_sec):
+ def add_database_transaction(self, duration_sec: float) -> None:
pass
- def add_database_scheduled(self, sched_sec):
+ def add_database_scheduled(self, sched_sec: float) -> None:
pass
- def record_event_fetch(self, event_count):
+ def record_event_fetch(self, event_count: int) -> None:
pass
def __bool__(self) -> Literal[False]:
@@ -379,7 +388,12 @@ class LoggingContext:
)
return self
- def __exit__(self, type, value, traceback) -> None:
+ def __exit__(
+ self,
+ type: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None:
"""Restore the logging context in thread local storage to the state it
was before this context was entered.
Returns:
@@ -399,17 +413,6 @@ class LoggingContext:
# recorded against the correct metrics.
self.finished = True
- def copy_to(self, record) -> None:
- """Copy logging fields from this context to a log record or
- another LoggingContext
- """
-
- # we track the current request
- record.request = self.request
-
- # we also track the current scope:
- record.scope = self.scope
-
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
"""
Record that this logcontext is currently running.
@@ -626,7 +629,12 @@ class PreserveLoggingContext:
def __enter__(self) -> None:
self._old_context = set_current_context(self._new_context)
- def __exit__(self, type, value, traceback) -> None:
+ def __exit__(
+ self,
+ type: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None:
context = set_current_context(self._old_context)
if context != self._new_context:
@@ -711,16 +719,61 @@ def nested_logging_context(suffix: str) -> LoggingContext:
)
-def preserve_fn(f):
+R = TypeVar("R")
+
+
+@overload
+def preserve_fn( # type: ignore[misc]
+ f: Callable[..., Awaitable[R]],
+) -> Callable[..., "defer.Deferred[R]"]:
+ # The `type: ignore[misc]` above suppresses
+ # "Overloaded function signatures 1 and 2 overlap with incompatible return types"
+ ...
+
+
+@overload
+def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]:
+ ...
+
+
+def preserve_fn(
+ f: Union[
+ Callable[..., R],
+ Callable[..., Awaitable[R]],
+ ]
+) -> Callable[..., "defer.Deferred[R]"]:
"""Function decorator which wraps the function with run_in_background"""
- def g(*args, **kwargs):
+ def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]":
return run_in_background(f, *args, **kwargs)
return g
-def run_in_background(f, *args, **kwargs) -> defer.Deferred:
+@overload
+def run_in_background( # type: ignore[misc]
+ f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any
+) -> "defer.Deferred[R]":
+ # The `type: ignore[misc]` above suppresses
+ # "Overloaded function signatures 1 and 2 overlap with incompatible return types"
+ ...
+
+
+@overload
+def run_in_background(
+ f: Callable[..., R], *args: Any, **kwargs: Any
+) -> "defer.Deferred[R]":
+ ...
+
+
+def run_in_background(
+ f: Union[
+ Callable[..., R],
+ Callable[..., Awaitable[R]],
+ ],
+ *args: Any,
+ **kwargs: Any,
+) -> "defer.Deferred[R]":
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
deferred returned by the function completes.
@@ -751,6 +804,10 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
# At this point we should have a Deferred, if not then f was a synchronous
# function, wrap it in a Deferred for consistency.
if not isinstance(res, defer.Deferred):
+ # `res` is not a `Deferred` and not a `Coroutine`.
+ # There are no other types of `Awaitable`s we expect to encounter in Synapse.
+ assert not isinstance(res, Awaitable)
+
return defer.succeed(res)
if res.called and not res.paused:
@@ -778,13 +835,14 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
return res
-def make_deferred_yieldable(deferred):
- """Given a deferred (or coroutine), make it follow the Synapse logcontext
- rules:
+T = TypeVar("T")
- If the deferred has completed (or is not actually a Deferred), essentially
- does nothing (just returns another completed deferred with the
- result/failure).
+
+def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
+ """Given a deferred, make it follow the Synapse logcontext rules:
+
+ If the deferred has completed, essentially does nothing (just returns another
+ completed deferred with the result/failure).
If the deferred has not yet completed, resets the logcontext before
returning a deferred. Then, when the deferred completes, restores the
@@ -792,16 +850,6 @@ def make_deferred_yieldable(deferred):
(This is more-or-less the opposite operation to run_in_background.)
"""
- if inspect.isawaitable(deferred):
- # If we're given a coroutine we convert it to a deferred so that we
- # run it and find out if it immediately finishes, it it does then we
- # don't need to fiddle with log contexts at all and can return
- # immediately.
- deferred = defer.ensureDeferred(deferred)
-
- if not isinstance(deferred, defer.Deferred):
- return deferred
-
if deferred.called and not deferred.paused:
# it looks like this deferred is ready to run any callbacks we give it
# immediately. We may as well optimise out the logcontext faffery.
@@ -823,7 +871,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
return result
-def defer_to_thread(reactor, f, *args, **kwargs):
+def defer_to_thread(
+ reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any
+) -> "defer.Deferred[R]":
"""
Calls the function `f` using a thread from the reactor's default threadpool and
returns the result as a Deferred.
@@ -855,7 +905,13 @@ def defer_to_thread(reactor, f, *args, **kwargs):
return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
-def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
+def defer_to_threadpool(
+ reactor: "ISynapseReactor",
+ threadpool: ThreadPool,
+ f: Callable[..., R],
+ *args: Any,
+ **kwargs: Any,
+) -> "defer.Deferred[R]":
"""
A wrapper for twisted.internet.threads.deferToThreadpool, which handles
logcontexts correctly.
@@ -897,7 +953,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
- def g():
+ def g() -> R:
with LoggingContext(str(curr_context), parent_context=parent_context):
return f(*args, **kwargs)
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 20d23a4260..622445e9f4 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -173,6 +173,7 @@ from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Typ
import attr
from twisted.internet import defer
+from twisted.web.http import Request
from twisted.web.http_headers import Headers
from synapse.config import ConfigError
@@ -219,11 +220,12 @@ class _DummyTagNames:
try:
import opentracing
+ import opentracing.tags
tags = opentracing.tags
except ImportError:
- opentracing = None
- tags = _DummyTagNames
+ opentracing = None # type: ignore[assignment]
+ tags = _DummyTagNames # type: ignore[assignment]
try:
from jaeger_client import Config as JaegerConfig
@@ -366,7 +368,7 @@ def init_tracer(hs: "HomeServer"):
global opentracing
if not hs.config.tracing.opentracer_enabled:
# We don't have a tracer
- opentracing = None
+ opentracing = None # type: ignore[assignment]
return
if not opentracing or not JaegerConfig:
@@ -452,7 +454,7 @@ def start_active_span(
"""
if opentracing is None:
- return noop_context_manager()
+ return noop_context_manager() # type: ignore[unreachable]
return opentracing.tracer.start_active_span(
operation_name,
@@ -477,7 +479,7 @@ def start_active_span_follows_from(
forced, the new span will also have tracing forced.
"""
if opentracing is None:
- return noop_context_manager()
+ return noop_context_manager() # type: ignore[unreachable]
references = [opentracing.follows_from(context) for context in contexts]
scope = start_active_span(operation_name, references=references)
@@ -490,48 +492,6 @@ def start_active_span_follows_from(
return scope
-def start_active_span_from_request(
- request,
- operation_name,
- references=None,
- tags=None,
- start_time=None,
- ignore_active_span=False,
- finish_on_close=True,
-):
- """
- Extracts a span context from a Twisted Request.
- args:
- headers (twisted.web.http.Request)
-
- For the other args see opentracing.tracer
-
- returns:
- span_context (opentracing.span.SpanContext)
- """
- # Twisted encodes the values as lists whereas opentracing doesn't.
- # So, we take the first item in the list.
- # Also, twisted uses byte arrays while opentracing expects strings.
-
- if opentracing is None:
- return noop_context_manager()
-
- header_dict = {
- k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
- }
- context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
-
- return opentracing.tracer.start_active_span(
- operation_name,
- child_of=context,
- references=references,
- tags=tags,
- start_time=start_time,
- ignore_active_span=ignore_active_span,
- finish_on_close=finish_on_close,
- )
-
-
def start_active_span_from_edu(
edu_content,
operation_name,
@@ -553,7 +513,7 @@ def start_active_span_from_edu(
references = references or []
if opentracing is None:
- return noop_context_manager()
+ return noop_context_manager() # type: ignore[unreachable]
carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
"opentracing", {}
@@ -594,18 +554,21 @@ def active_span():
@ensure_active_span("set a tag")
def set_tag(key, value):
"""Sets a tag on the active span"""
+ assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_tag(key, value)
@ensure_active_span("log")
def log_kv(key_values, timestamp=None):
"""Log to the active span"""
+ assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.log_kv(key_values, timestamp)
@ensure_active_span("set the traces operation name")
def set_operation_name(operation_name):
"""Sets the operation name of the active span"""
+ assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_operation_name(operation_name)
@@ -674,6 +637,7 @@ def inject_header_dict(
span = opentracing.tracer.active_span
carrier: Dict[str, str] = {}
+ assert span is not None
opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@@ -716,6 +680,7 @@ def get_active_span_text_map(destination=None):
return {}
carrier: Dict[str, str] = {}
+ assert opentracing.tracer.active_span is not None
opentracing.tracer.inject(
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
)
@@ -731,12 +696,27 @@ def active_span_context_as_string():
"""
carrier: Dict[str, str] = {}
if opentracing:
+ assert opentracing.tracer.active_span is not None
opentracing.tracer.inject(
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
)
return json_encoder.encode(carrier)
+def span_context_from_request(request: Request) -> "Optional[opentracing.SpanContext]":
+ """Extract an opentracing context from the headers on an HTTP request
+
+ This is useful when we have received an HTTP request from another part of our
+ system, and want to link our spans to those of the remote system.
+ """
+ if not opentracing:
+ return None
+ header_dict = {
+ k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
+ }
+ return opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
+
+
@only_if_tracing
def span_context_from_string(carrier):
"""
@@ -773,7 +753,7 @@ def trace(func=None, opname=None):
def decorator(func):
if opentracing is None:
- return func
+ return func # type: ignore[unreachable]
_opname = opname if opname else func.__name__
@@ -864,7 +844,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
"""
if opentracing is None:
- yield
+ yield # type: ignore[unreachable]
return
request_tags = {
@@ -876,10 +856,13 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
}
request_name = request.request_metrics.name
- if extract_context:
- scope = start_active_span_from_request(request, request_name)
- else:
- scope = start_active_span(request_name)
+ context = span_context_from_request(request) if extract_context else None
+
+ # we configure the scope not to finish the span immediately on exit, and instead
+ # pass the span into the SynapseRequest, which will finish it once we've finished
+ # sending the response to the client.
+ scope = start_active_span(request_name, child_of=context, finish_on_close=False)
+ request.set_opentracing_span(scope.span)
with scope:
inject_response_headers(request.responseHeaders)
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index b1e8e08fe9..db8ca2c049 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -71,7 +71,7 @@ class LogContextScopeManager(ScopeManager):
if not ctx:
# We don't want this scope to affect.
logger.error("Tried to activate scope outside of loggingcontext")
- return Scope(None, span)
+ return Scope(None, span) # type: ignore[arg-type]
elif ctx.scope is not None:
# We want the logging scope to look exactly the same so we give it
# a blank suffix
|