summary refs log tree commit diff
path: root/synapse/logging/context.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/logging/context.py')
-rw-r--r--synapse/logging/context.py83
1 files changed, 59 insertions, 24 deletions
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 02e5ddd2ef..fc3e514ec8 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -22,21 +22,37 @@ them.
 
 See doc/log_contexts.rst for details on how this works.
 """
+import functools
 import inspect
 import logging
 import threading
 import typing
 import warnings
-from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
+from types import TracebackType
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import attr
 from typing_extensions import Literal
 
 from twisted.internet import defer, threads
+from twisted.internet.interfaces import IReactorThreads, ThreadPool
 
 if TYPE_CHECKING:
     from synapse.logging.scopecontextmanager import _LogContextScope
 
+T = TypeVar("T")
+F = TypeVar("F", bound=Callable[..., Any])
+
 logger = logging.getLogger(__name__)
 
 try:
@@ -66,7 +82,7 @@ except Exception:
 
 
 # a hook which can be set during testing to assert that we aren't abusing logcontexts.
-def logcontext_error(msg: str):
+def logcontext_error(msg: str) -> None:
     logger.warning(msg)
 
 
@@ -220,28 +236,28 @@ class _Sentinel:
         self.scope = None
         self.tag = None
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "sentinel"
 
-    def copy_to(self, record):
+    def copy_to(self, record: "LoggingContext") -> None:
         pass
 
-    def start(self, rusage: "Optional[resource._RUsage]"):
+    def start(self, rusage: "Optional[resource._RUsage]") -> None:
         pass
 
-    def stop(self, rusage: "Optional[resource._RUsage]"):
+    def stop(self, rusage: "Optional[resource._RUsage]") -> None:
         pass
 
-    def add_database_transaction(self, duration_sec):
+    def add_database_transaction(self, duration_sec: float) -> None:
         pass
 
-    def add_database_scheduled(self, sched_sec):
+    def add_database_scheduled(self, sched_sec: float) -> None:
         pass
 
-    def record_event_fetch(self, event_count):
+    def record_event_fetch(self, event_count: int) -> None:
         pass
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return False
 
 
@@ -379,7 +395,12 @@ class LoggingContext:
             )
         return self
 
-    def __exit__(self, type, value, traceback) -> None:
+    def __exit__(
+        self,
+        type: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
         """Restore the logging context in thread local storage to the state it
         was before this context was entered.
         Returns:
@@ -399,10 +420,8 @@ class LoggingContext:
         # recorded against the correct metrics.
         self.finished = True
 
-    def copy_to(self, record) -> None:
-        """Copy logging fields from this context to a log record or
-        another LoggingContext
-        """
+    def copy_to(self, record: "LoggingContext") -> None:
+        """Copy logging fields from this context to another LoggingContext"""
 
         # we track the current request
         record.request = self.request
@@ -575,7 +594,7 @@ class LoggingContextFilter(logging.Filter):
     record.
     """
 
-    def __init__(self, request: str = ""):
+    def __init__(self, request: str = "") -> None:
         self._default_request = request
 
     def filter(self, record: logging.LogRecord) -> Literal[True]:
@@ -626,7 +645,12 @@ class PreserveLoggingContext:
     def __enter__(self) -> None:
         self._old_context = set_current_context(self._new_context)
 
-    def __exit__(self, type, value, traceback) -> None:
+    def __exit__(
+        self,
+        type: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
         context = set_current_context(self._old_context)
 
         if context != self._new_context:
@@ -711,16 +735,19 @@ def nested_logging_context(suffix: str) -> LoggingContext:
     )
 
 
-def preserve_fn(f):
+def preserve_fn(f: F) -> F:
     """Function decorator which wraps the function with run_in_background"""
 
-    def g(*args, **kwargs):
+    @functools.wraps(f)
+    def g(*args: Any, **kwargs: Any) -> Any:
         return run_in_background(f, *args, **kwargs)
 
-    return g
+    return cast(F, g)
 
 
-def run_in_background(f, *args, **kwargs) -> defer.Deferred:
+def run_in_background(
+    f: Callable[..., T], *args: Any, **kwargs: Any
+) -> "defer.Deferred[T]":
     """Calls a function, ensuring that the current context is restored after
     return from the function, and that the sentinel context is set once the
     deferred returned by the function completes.
@@ -823,7 +850,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
     return result
 
 
-def defer_to_thread(reactor, f, *args, **kwargs):
+def defer_to_thread(
+    reactor: IReactorThreads, f: Callable[..., T], *args: Any, **kwargs: Any
+) -> "defer.Deferred[T]":
     """
     Calls the function `f` using a thread from the reactor's default threadpool and
     returns the result as a Deferred.
@@ -855,7 +884,13 @@ def defer_to_thread(reactor, f, *args, **kwargs):
     return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
 
 
-def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
+def defer_to_threadpool(
+    reactor: IReactorThreads,
+    threadpool: ThreadPool,
+    f: Callable[..., T],
+    *args: Any,
+    **kwargs: Any,
+) -> "defer.Deferred[T]":
     """
     A wrapper for twisted.internet.threads.deferToThreadpool, which handles
     logcontexts correctly.
@@ -897,7 +932,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
         assert isinstance(curr_context, LoggingContext)
         parent_context = curr_context
 
-    def g():
+    def g() -> T:
         with LoggingContext(str(curr_context), parent_context=parent_context):
             return f(*args, **kwargs)