summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2024-06-03 14:16:35 +0100
committerErik Johnston <erik@matrix.org>2024-07-02 16:24:05 +0100
commit3c3c0dd419aac1b46d6838094863d5895a183b05 (patch)
tree32ddd37eaaf75a34d97ebeb679fa739d63f751d9 /synapse
parentMerge remote-tracking branch 'origin/release-v1.110' into develop (diff)
downloadsynapse-3c3c0dd419aac1b46d6838094863d5895a183b05.tar.xz
WIP
Diffstat (limited to 'synapse')
-rw-r--r--synapse/logging/context.py83
1 files changed, 79 insertions, 4 deletions
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 4650b60962..5e7cc4c3b7 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -33,9 +33,12 @@ import logging
 import threading
 import typing
 import warnings
+from collections.abc import Coroutine, Generator
+from contextvars import ContextVar
 from types import TracebackType
 from typing import (
     TYPE_CHECKING,
+    Any,
     Awaitable,
     Callable,
     Optional,
@@ -657,13 +660,14 @@ class PreserveLoggingContext:
                 )
 
 
-_thread_local = threading.local()
-_thread_local.current_context = SENTINEL_CONTEXT
+_CURRENT_CONTEXT_VAR: ContextVar[LoggingContextOrSentinel] = ContextVar(
+    "current_context", default=SENTINEL_CONTEXT
+)
 
 
 def current_context() -> LoggingContextOrSentinel:
     """Get the current logging context from thread local storage"""
-    return getattr(_thread_local, "current_context", SENTINEL_CONTEXT)
+    return _CURRENT_CONTEXT_VAR.get()
 
 
 def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
@@ -684,7 +688,7 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe
     if current is not context:
         rusage = get_thread_resource_usage()
         current.stop(rusage)
-        _thread_local.current_context = context
+        _CURRENT_CONTEXT_VAR.set(context)
         context.start(rusage)
 
     return current
@@ -971,3 +975,74 @@ def defer_to_threadpool(
             return f(*args, **kwargs)
 
     return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
+
+
+_T = TypeVar("_T")
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _ResourceTracker(Generator[defer.Deferred[Any], Any, _T]):
+    gen: Generator[defer.Deferred[Any], Any, _T]
+
+    def send(self, val: Any) -> defer.Deferred[_T]:
+        try:
+            return self.gen.send(val)
+        finally:
+            pass
+
+    @overload
+    def throw(
+        self,
+        a: Type[BaseException],
+        b: object = ...,
+        c: Optional[TracebackType] = ...,
+        /,
+    ) -> defer.Deferred[Any]: ...
+
+    @overload
+    def throw(
+        self, a: BaseException, v: None = ..., c: Optional[TracebackType] = ..., /
+    ) -> defer.Deferred[Any]: ...
+
+    def throw(self, a: Any, b: Any = None, c: Any = None) -> defer.Deferred[Any]:
+        try:
+            return self.throw(a, b, c)
+        finally:
+            pass
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _ResourceTracker2(Coroutine[defer.Deferred[Any], Any, _T]):
+    gen: Coroutine[defer.Deferred[Any], Any, _T]
+
+    def send(self, val: Any) -> defer.Deferred[_T]:
+        try:
+            return self.gen.send(val)
+        finally:
+            pass
+
+    @overload
+    def throw(
+        self,
+        a: Type[BaseException],
+        b: object = ...,
+        c: Optional[TracebackType] = ...,
+        /,
+    ) -> defer.Deferred[Any]: ...
+
+    @overload
+    def throw(
+        self, a: BaseException, v: None = ..., c: Optional[TracebackType] = ..., /
+    ) -> defer.Deferred[Any]: ...
+
+    def throw(self, a: Any, b: Any = None, c: Any = None) -> defer.Deferred[Any]:
+        try:
+            return self.throw(a, b, c)
+        finally:
+            pass
+
+    def __await__(self) -> Generator[defer.Deferred[Any], Any, _T]:
+        return _ResourceTracker(self.gen.__await__())
+
+    def close(self) -> None:
+        return self.gen.close()