diff options
author | Erik Johnston <erik@matrix.org> | 2024-06-03 14:16:35 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2024-07-02 16:24:05 +0100 |
commit | 3c3c0dd419aac1b46d6838094863d5895a183b05 (patch) | |
tree | 32ddd37eaaf75a34d97ebeb679fa739d63f751d9 /synapse | |
parent | Merge remote-tracking branch 'origin/release-v1.110' into develop (diff) | |
download | synapse-3c3c0dd419aac1b46d6838094863d5895a183b05.tar.xz |
WIP
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/logging/context.py | 83 |
1 files changed, 79 insertions, 4 deletions
diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 4650b60962..5e7cc4c3b7 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -33,9 +33,12 @@ import logging import threading import typing import warnings +from collections.abc import Coroutine, Generator +from contextvars import ContextVar from types import TracebackType from typing import ( TYPE_CHECKING, + Any, Awaitable, Callable, Optional, @@ -657,13 +660,14 @@ class PreserveLoggingContext: ) -_thread_local = threading.local() -_thread_local.current_context = SENTINEL_CONTEXT +_CURRENT_CONTEXT_VAR: ContextVar[LoggingContextOrSentinel] = ContextVar( + "current_context", default=SENTINEL_CONTEXT +) def current_context() -> LoggingContextOrSentinel: """Get the current logging context from thread local storage""" - return getattr(_thread_local, "current_context", SENTINEL_CONTEXT) + return _CURRENT_CONTEXT_VAR.get() def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel: @@ -684,7 +688,7 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe if current is not context: rusage = get_thread_resource_usage() current.stop(rusage) - _thread_local.current_context = context + _CURRENT_CONTEXT_VAR.set(context) context.start(rusage) return current @@ -971,3 +975,74 @@ def defer_to_threadpool( return f(*args, **kwargs) return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g)) + + +_T = TypeVar("_T") + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class _ResourceTracker(Generator[defer.Deferred[Any], Any, _T]): + gen: Generator[defer.Deferred[Any], Any, _T] + + def send(self, val: Any) -> defer.Deferred[_T]: + try: + return self.gen.send(val) + finally: + pass + + @overload + def throw( + self, + a: Type[BaseException], + b: object = ..., + c: Optional[TracebackType] = ..., + /, + ) -> defer.Deferred[Any]: ... + + @overload + def throw( + self, a: BaseException, v: None = ..., c: Optional[TracebackType] = ..., / + ) -> defer.Deferred[Any]: ... + + def throw(self, a: Any, b: Any = None, c: Any = None) -> defer.Deferred[Any]: + try: + return self.throw(a, b, c) + finally: + pass + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class _ResourceTracker2(Coroutine[defer.Deferred[Any], Any, _T]): + gen: Coroutine[defer.Deferred[Any], Any, _T] + + def send(self, val: Any) -> defer.Deferred[_T]: + try: + return self.gen.send(val) + finally: + pass + + @overload + def throw( + self, + a: Type[BaseException], + b: object = ..., + c: Optional[TracebackType] = ..., + /, + ) -> defer.Deferred[Any]: ... + + @overload + def throw( + self, a: BaseException, v: None = ..., c: Optional[TracebackType] = ..., / + ) -> defer.Deferred[Any]: ... + + def throw(self, a: Any, b: Any = None, c: Any = None) -> defer.Deferred[Any]: + try: + return self.throw(a, b, c) + finally: + pass + + def __await__(self) -> Generator[defer.Deferred[Any], Any, _T]: + return _ResourceTracker(self.gen.__await__()) + + def close(self) -> None: + return self.gen.close() |