diff options
Diffstat (limited to 'synapse/logging/context.py')
-rw-r--r-- | synapse/logging/context.py | 437 |
1 files changed, 283 insertions, 154 deletions
diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 370000e377..8b9c4e38bd 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -23,13 +23,20 @@ them. See doc/log_contexts.rst for details on how this works. """ +import inspect import logging import threading import types -from typing import Any, List +import warnings +from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union + +from typing_extensions import Literal from twisted.internet import defer, threads +if TYPE_CHECKING: + from synapse.logging.scopecontextmanager import _LogContextScope + logger = logging.getLogger(__name__) try: @@ -45,7 +52,7 @@ try: is_thread_resource_usage_supported = True - def get_thread_resource_usage(): + def get_thread_resource_usage() -> "Optional[resource._RUsage]": return resource.getrusage(RUSAGE_THREAD) @@ -54,7 +61,7 @@ except Exception: # won't track resource usage. is_thread_resource_usage_supported = False - def get_thread_resource_usage(): + def get_thread_resource_usage() -> "Optional[resource._RUsage]": return None @@ -90,7 +97,7 @@ class ContextResourceUsage(object): "evt_db_fetch_count", ] - def __init__(self, copy_from=None): + def __init__(self, copy_from: "Optional[ContextResourceUsage]" = None) -> None: """Create a new ContextResourceUsage Args: @@ -100,27 +107,28 @@ class ContextResourceUsage(object): if copy_from is None: self.reset() else: - self.ru_utime = copy_from.ru_utime - self.ru_stime = copy_from.ru_stime - self.db_txn_count = copy_from.db_txn_count + # FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now + self.ru_utime = copy_from.ru_utime # type: float + self.ru_stime = copy_from.ru_stime # type: float + self.db_txn_count = copy_from.db_txn_count # type: int - self.db_txn_duration_sec = copy_from.db_txn_duration_sec - self.db_sched_duration_sec = copy_from.db_sched_duration_sec - self.evt_db_fetch_count = copy_from.evt_db_fetch_count + self.db_txn_duration_sec = copy_from.db_txn_duration_sec # type: float + self.db_sched_duration_sec = copy_from.db_sched_duration_sec # type: float + self.evt_db_fetch_count = copy_from.evt_db_fetch_count # type: int - def copy(self): + def copy(self) -> "ContextResourceUsage": return ContextResourceUsage(copy_from=self) - def reset(self): + def reset(self) -> None: self.ru_stime = 0.0 self.ru_utime = 0.0 self.db_txn_count = 0 - self.db_txn_duration_sec = 0 - self.db_sched_duration_sec = 0 + self.db_txn_duration_sec = 0.0 + self.db_sched_duration_sec = 0.0 self.evt_db_fetch_count = 0 - def __repr__(self): + def __repr__(self) -> str: return ( "<ContextResourceUsage ru_stime='%r', ru_utime='%r', " "db_txn_count='%r', db_txn_duration_sec='%r', " @@ -134,7 +142,7 @@ class ContextResourceUsage(object): self.evt_db_fetch_count, ) - def __iadd__(self, other): + def __iadd__(self, other: "ContextResourceUsage") -> "ContextResourceUsage": """Add another ContextResourceUsage's stats to this one's. Args: @@ -148,7 +156,7 @@ class ContextResourceUsage(object): self.evt_db_fetch_count += other.evt_db_fetch_count return self - def __isub__(self, other): + def __isub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage": self.ru_utime -= other.ru_utime self.ru_stime -= other.ru_stime self.db_txn_count -= other.db_txn_count @@ -157,17 +165,67 @@ class ContextResourceUsage(object): self.evt_db_fetch_count -= other.evt_db_fetch_count return self - def __add__(self, other): + def __add__(self, other: "ContextResourceUsage") -> "ContextResourceUsage": res = ContextResourceUsage(copy_from=self) res += other return res - def __sub__(self, other): + def __sub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage": res = ContextResourceUsage(copy_from=self) res -= other return res +LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"] + + +class _Sentinel(object): + """Sentinel to represent the root context""" + + __slots__ = ["previous_context", "finished", "request", "scope", "tag"] + + def __init__(self) -> None: + # Minimal set for compatibility with LoggingContext + self.previous_context = None + self.finished = False + self.request = None + self.scope = None + self.tag = None + + def __str__(self): + return "sentinel" + + def copy_to(self, record): + pass + + def copy_to_twisted_log_entry(self, record): + record["request"] = None + record["scope"] = None + + def start(self, rusage: "Optional[resource._RUsage]"): + pass + + def stop(self, rusage: "Optional[resource._RUsage]"): + pass + + def add_database_transaction(self, duration_sec): + pass + + def add_database_scheduled(self, sched_sec): + pass + + def record_event_fetch(self, event_count): + pass + + def __nonzero__(self): + return False + + __bool__ = __nonzero__ # python3 + + +SENTINEL_CONTEXT = _Sentinel() + + class LoggingContext(object): """Additional context for log formatting. Contexts are scoped within a "with" block. @@ -189,67 +247,32 @@ class LoggingContext(object): "_resource_usage", "usage_start", "main_thread", - "alive", + "finished", "request", "tag", "scope", ] - thread_local = threading.local() - - class Sentinel(object): - """Sentinel to represent the root context""" - - __slots__ = [] # type: List[Any] - - def __str__(self): - return "sentinel" - - def copy_to(self, record): - pass - - def copy_to_twisted_log_entry(self, record): - record["request"] = None - record["scope"] = None - - def start(self): - pass - - def stop(self): - pass - - def add_database_transaction(self, duration_sec): - pass - - def add_database_scheduled(self, sched_sec): - pass - - def record_event_fetch(self, event_count): - pass - - def __nonzero__(self): - return False - - __bool__ = __nonzero__ # python3 - - sentinel = Sentinel() - - def __init__(self, name=None, parent_context=None, request=None): - self.previous_context = LoggingContext.current_context() + def __init__(self, name=None, parent_context=None, request=None) -> None: + self.previous_context = current_context() self.name = name # track the resources used by this context so far self._resource_usage = ContextResourceUsage() - # If alive has the thread resource usage when the logcontext last - # became active. - self.usage_start = None + # The thread resource usage when the logcontext became active. None + # if the context is not currently active. + self.usage_start = None # type: Optional[resource._RUsage] self.main_thread = get_thread_id() self.request = None self.tag = "" - self.alive = True - self.scope = None + self.scope = None # type: Optional[_LogContextScope] + + # keep track of whether we have hit the __exit__ block for this context + # (suggesting that the the thing that created the context thinks it should + # be finished, and that re-activating it would suggest an error). + self.finished = False self.parent_context = parent_context @@ -260,76 +283,83 @@ class LoggingContext(object): # the request param overrides the request from the parent context self.request = request - def __str__(self): + def __str__(self) -> str: if self.request: return str(self.request) return "%s@%x" % (self.name, id(self)) @classmethod - def current_context(cls): + def current_context(cls) -> LoggingContextOrSentinel: """Get the current logging context from thread local storage + This exists for backwards compatibility. ``current_context()`` should be + called directly. + Returns: LoggingContext: the current logging context """ - return getattr(cls.thread_local, "current_context", cls.sentinel) + warnings.warn( + "synapse.logging.context.LoggingContext.current_context() is deprecated " + "in favor of synapse.logging.context.current_context().", + DeprecationWarning, + stacklevel=2, + ) + return current_context() @classmethod - def set_current_context(cls, context): + def set_current_context( + cls, context: LoggingContextOrSentinel + ) -> LoggingContextOrSentinel: """Set the current logging context in thread local storage + + This exists for backwards compatibility. ``set_current_context()`` should be + called directly. + Args: context(LoggingContext): The context to activate. Returns: The context that was previously active """ - current = cls.current_context() - - if current is not context: - current.stop() - cls.thread_local.current_context = context - context.start() - return current + warnings.warn( + "synapse.logging.context.LoggingContext.set_current_context() is deprecated " + "in favor of synapse.logging.context.set_current_context().", + DeprecationWarning, + stacklevel=2, + ) + return set_current_context(context) - def __enter__(self): + def __enter__(self) -> "LoggingContext": """Enters this logging context into thread local storage""" - old_context = self.set_current_context(self) + old_context = set_current_context(self) if self.previous_context != old_context: - logger.warn( + logger.warning( "Expected previous context %r, found %r", self.previous_context, old_context, ) - self.alive = True - return self - def __exit__(self, type, value, traceback): + def __exit__(self, type, value, traceback) -> None: """Restore the logging context in thread local storage to the state it was before this context was entered. Returns: None to avoid suppressing any exceptions that were thrown. """ - current = self.set_current_context(self.previous_context) + current = set_current_context(self.previous_context) if current is not self: - if current is self.sentinel: + if current is SENTINEL_CONTEXT: logger.warning("Expected logging context %s was lost", self) else: logger.warning( "Expected logging context %s but found %s", self, current ) - self.previous_context = None - self.alive = False - - # if we have a parent, pass our CPU usage stats on - if self.parent_context is not None and hasattr( - self.parent_context, "_resource_usage" - ): - self.parent_context._resource_usage += self._resource_usage - # reset them in case we get entered again - self._resource_usage.reset() + # the fact that we are here suggests that the caller thinks that everything + # is done and dusted for this logcontext, and further activity will not get + # recorded against the correct metrics. + self.finished = True - def copy_to(self, record): + def copy_to(self, record) -> None: """Copy logging fields from this context to a log record or another LoggingContext """ @@ -340,44 +370,72 @@ class LoggingContext(object): # we also track the current scope: record.scope = self.scope - def copy_to_twisted_log_entry(self, record): + def copy_to_twisted_log_entry(self, record) -> None: """ Copy logging fields from this context to a Twisted log record. """ record["request"] = self.request record["scope"] = self.scope - def start(self): + def start(self, rusage: "Optional[resource._RUsage]") -> None: + """ + Record that this logcontext is currently running. + + This should not be called directly: use set_current_context + + Args: + rusage: the resources used by the current thread, at the point of + switching to this logcontext. May be None if this platform doesn't + support getrusuage. + """ if get_thread_id() != self.main_thread: logger.warning("Started logcontext %s on different thread", self) return + if self.finished: + logger.warning("Re-starting finished log context %s", self) + # If we haven't already started record the thread resource usage so # far - if not self.usage_start: - self.usage_start = get_thread_resource_usage() + if self.usage_start: + logger.warning("Re-starting already-active log context %s", self) + else: + self.usage_start = rusage - def stop(self): - if get_thread_id() != self.main_thread: - logger.warning("Stopped logcontext %s on different thread", self) - return + def stop(self, rusage: "Optional[resource._RUsage]") -> None: + """ + Record that this logcontext is no longer running. - # When we stop, let's record the cpu used since we started - if not self.usage_start: - # Log a warning on platforms that support thread usage tracking - if is_thread_resource_usage_supported: + This should not be called directly: use set_current_context + + Args: + rusage: the resources used by the current thread, at the point of + switching away from this logcontext. May be None if this platform + doesn't support getrusuage. + """ + + try: + if get_thread_id() != self.main_thread: + logger.warning("Stopped logcontext %s on different thread", self) + return + + if not rusage: + return + + # Record the cpu used since we started + if not self.usage_start: logger.warning( - "Called stop on logcontext %s without calling start", self + "Called stop on logcontext %s without recording a start rusage", + self, ) - return + return - utime_delta, stime_delta = self._get_cputime() - self._resource_usage.ru_utime += utime_delta - self._resource_usage.ru_stime += stime_delta - - self.usage_start = None + utime_delta, stime_delta = self._get_cputime(rusage) + self.add_cputime(utime_delta, stime_delta) + finally: + self.usage_start = None - def get_resource_usage(self): + def get_resource_usage(self) -> ContextResourceUsage: """Get resources used by this logcontext so far. Returns: @@ -390,19 +448,24 @@ class LoggingContext(object): # If we are on the correct thread and we're currently running then we # can include resource usage so far. is_main_thread = get_thread_id() == self.main_thread - if self.alive and self.usage_start and is_main_thread: - utime_delta, stime_delta = self._get_cputime() + if self.usage_start and is_main_thread: + rusage = get_thread_resource_usage() + assert rusage is not None + utime_delta, stime_delta = self._get_cputime(rusage) res.ru_utime += utime_delta res.ru_stime += stime_delta return res - def _get_cputime(self): - """Get the cpu usage time so far + def _get_cputime(self, current: "resource._RUsage") -> Tuple[float, float]: + """Get the cpu usage time between start() and the given rusage + + Args: + rusage: the current resource usage Returns: Tuple[float, float]: seconds in user mode, seconds in system mode """ - current = get_thread_resource_usage() + assert self.usage_start is not None utime_delta = current.ru_utime - self.usage_start.ru_utime stime_delta = current.ru_stime - self.usage_start.ru_stime @@ -426,30 +489,52 @@ class LoggingContext(object): return utime_delta, stime_delta - def add_database_transaction(self, duration_sec): + def add_cputime(self, utime_delta: float, stime_delta: float) -> None: + """Update the CPU time usage of this context (and any parents, recursively). + + Args: + utime_delta: additional user time, in seconds, spent in this context. + stime_delta: additional system time, in seconds, spent in this context. + """ + self._resource_usage.ru_utime += utime_delta + self._resource_usage.ru_stime += stime_delta + if self.parent_context: + self.parent_context.add_cputime(utime_delta, stime_delta) + + def add_database_transaction(self, duration_sec: float) -> None: + """Record the use of a database transaction and the length of time it took. + + Args: + duration_sec: The number of seconds the database transaction took. + """ if duration_sec < 0: raise ValueError("DB txn time can only be non-negative") self._resource_usage.db_txn_count += 1 self._resource_usage.db_txn_duration_sec += duration_sec + if self.parent_context: + self.parent_context.add_database_transaction(duration_sec) - def add_database_scheduled(self, sched_sec): + def add_database_scheduled(self, sched_sec: float) -> None: """Record a use of the database pool Args: - sched_sec (float): number of seconds it took us to get a - connection + sched_sec: number of seconds it took us to get a connection """ if sched_sec < 0: raise ValueError("DB scheduling time can only be non-negative") self._resource_usage.db_sched_duration_sec += sched_sec + if self.parent_context: + self.parent_context.add_database_scheduled(sched_sec) - def record_event_fetch(self, event_count): + def record_event_fetch(self, event_count: int) -> None: """Record a number of events being fetched from the db Args: - event_count (int): number of events being fetched + event_count: number of events being fetched """ self._resource_usage.evt_db_fetch_count += event_count + if self.parent_context: + self.parent_context.record_event_fetch(event_count) class LoggingContextFilter(logging.Filter): @@ -460,15 +545,15 @@ class LoggingContextFilter(logging.Filter): missing fields """ - def __init__(self, **defaults): + def __init__(self, **defaults) -> None: self.defaults = defaults - def filter(self, record): + def filter(self, record) -> Literal[True]: """Add each fields from the logging contexts to the record. Returns: True to include the record in the log output. """ - context = LoggingContext.current_context() + context = current_context() for key, value in self.defaults.items(): setattr(record, key, value) @@ -488,26 +573,24 @@ class PreserveLoggingContext(object): __slots__ = ["current_context", "new_context", "has_parent"] - def __init__(self, new_context=None): - if new_context is None: - new_context = LoggingContext.sentinel + def __init__( + self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT + ) -> None: self.new_context = new_context - def __enter__(self): + def __enter__(self) -> None: """Captures the current logging context""" - self.current_context = LoggingContext.set_current_context(self.new_context) + self.current_context = set_current_context(self.new_context) if self.current_context: self.has_parent = self.current_context.previous_context is not None - if not self.current_context.alive: - logger.debug("Entering dead context: %s", self.current_context) - def __exit__(self, type, value, traceback): + def __exit__(self, type, value, traceback) -> None: """Restores the current logging context""" - context = LoggingContext.set_current_context(self.current_context) + context = set_current_context(self.current_context) if context != self.new_context: - if context is LoggingContext.sentinel: + if not context: logger.warning("Expected logging context %s was lost", self.new_context) else: logger.warning( @@ -516,12 +599,42 @@ class PreserveLoggingContext(object): context, ) - if self.current_context is not LoggingContext.sentinel: - if not self.current_context.alive: - logger.debug("Restoring dead context: %s", self.current_context) +_thread_local = threading.local() +_thread_local.current_context = SENTINEL_CONTEXT + + +def current_context() -> LoggingContextOrSentinel: + """Get the current logging context from thread local storage""" + return getattr(_thread_local, "current_context", SENTINEL_CONTEXT) -def nested_logging_context(suffix, parent_context=None): + +def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel: + """Set the current logging context in thread local storage + Args: + context(LoggingContext): The context to activate. + Returns: + The context that was previously active + """ + # everything blows up if we allow current_context to be set to None, so sanity-check + # that now. + if context is None: + raise TypeError("'context' argument may not be None") + + current = current_context() + + if current is not context: + rusage = get_thread_resource_usage() + current.stop(rusage) + _thread_local.current_context = context + context.start(rusage) + + return current + + +def nested_logging_context( + suffix: str, parent_context: Optional[LoggingContext] = None +) -> LoggingContext: """Creates a new logging context as a child of another. The nested logging context will have a 'request' made up of the parent context's @@ -542,10 +655,12 @@ def nested_logging_context(suffix, parent_context=None): Returns: LoggingContext: new logging context. """ - if parent_context is None: - parent_context = LoggingContext.current_context() + if parent_context is not None: + context = parent_context # type: LoggingContextOrSentinel + else: + context = current_context() return LoggingContext( - parent_context=parent_context, request=parent_context.request + "-" + suffix + parent_context=context, request=str(context.request) + "-" + suffix ) @@ -567,12 +682,15 @@ def run_in_background(f, *args, **kwargs): yield or await on (for instance because you want to pass it to deferred.gatherResults()). + If f returns a Coroutine object, it will be wrapped into a Deferred (which will have + the side effect of executing the coroutine). + Note that if you completely discard the result, you should make sure that `f` doesn't raise any deferred exceptions, otherwise a scary-looking CRITICAL error about an unhandled error will be logged without much indication about where it came from. """ - current = LoggingContext.current_context() + current = current_context() try: res = f(*args, **kwargs) except: # noqa: E722 @@ -593,7 +711,7 @@ def run_in_background(f, *args, **kwargs): # The function may have reset the context before returning, so # we need to restore it now. - ctx = LoggingContext.set_current_context(current) + ctx = set_current_context(current) # The original context will be restored when the deferred # completes, but there is nothing waiting for it, so it will @@ -612,7 +730,8 @@ def run_in_background(f, *args, **kwargs): def make_deferred_yieldable(deferred): - """Given a deferred, make it follow the Synapse logcontext rules: + """Given a deferred (or coroutine), make it follow the Synapse logcontext + rules: If the deferred has completed (or is not actually a Deferred), essentially does nothing (just returns another completed deferred with the @@ -624,6 +743,13 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to run_in_background.) """ + if inspect.isawaitable(deferred): + # If we're given a coroutine we convert it to a deferred so that we + # run it and find out if it immediately finishes, it it does then we + # don't need to fiddle with log contexts at all and can return + # immediately. + deferred = defer.ensureDeferred(deferred) + if not isinstance(deferred, defer.Deferred): return deferred @@ -634,14 +760,17 @@ def make_deferred_yieldable(deferred): # ok, we can't be sure that a yield won't block, so let's reset the # logcontext, and add a callback to the deferred to restore it. - prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) + prev_context = set_current_context(SENTINEL_CONTEXT) deferred.addBoth(_set_context_cb, prev_context) return deferred -def _set_context_cb(result, context): +ResultT = TypeVar("ResultT") + + +def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: """A callback function which just sets the logging context""" - LoggingContext.set_current_context(context) + set_current_context(context) return result @@ -709,7 +838,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): Deferred: A Deferred which fires a callback with the result of `f`, or an errback if `f` throws an exception. """ - logcontext = LoggingContext.current_context() + logcontext = current_context() def g(): with LoggingContext(parent_context=logcontext): |