diff options
author | David Robertson <davidr@element.io> | 2021-10-04 13:15:51 +0100 |
---|---|---|
committer | David Robertson <davidr@element.io> | 2021-10-04 14:07:56 +0100 |
commit | 36f47b37a9dad97879a0dfdc1e79fd7db1e78c19 (patch) | |
tree | 8a5cec743861bf686726a7d86622f09abd2a9924 | |
parent | no-untyped-defs for synapse.logging.handlers (diff) | |
download | synapse-github/dmr/synapse.logging-typing.tar.xz |
Easier fn annotations for synapse.logging.context github/dmr/synapse.logging-typing dmr/synapse.logging-typing
Not yet passing `no-untyped-defs`. `make_deferred_yieldable` is tricky and needs more thought.
-rw-r--r-- | synapse/logging/context.py | 83 |
1 files changed, 59 insertions, 24 deletions
diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 02e5ddd2ef..fc3e514ec8 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -22,21 +22,37 @@ them. See doc/log_contexts.rst for details on how this works. """ +import functools import inspect import logging import threading import typing import warnings -from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) import attr from typing_extensions import Literal from twisted.internet import defer, threads +from twisted.internet.interfaces import IReactorThreads, ThreadPool if TYPE_CHECKING: from synapse.logging.scopecontextmanager import _LogContextScope +T = TypeVar("T") +F = TypeVar("F", bound=Callable[..., Any]) + logger = logging.getLogger(__name__) try: @@ -66,7 +82,7 @@ except Exception: # a hook which can be set during testing to assert that we aren't abusing logcontexts. -def logcontext_error(msg: str): +def logcontext_error(msg: str) -> None: logger.warning(msg) @@ -220,28 +236,28 @@ class _Sentinel: self.scope = None self.tag = None - def __str__(self): + def __str__(self) -> str: return "sentinel" - def copy_to(self, record): + def copy_to(self, record: "LoggingContext") -> None: pass - def start(self, rusage: "Optional[resource._RUsage]"): + def start(self, rusage: "Optional[resource._RUsage]") -> None: pass - def stop(self, rusage: "Optional[resource._RUsage]"): + def stop(self, rusage: "Optional[resource._RUsage]") -> None: pass - def add_database_transaction(self, duration_sec): + def add_database_transaction(self, duration_sec: float) -> None: pass - def add_database_scheduled(self, sched_sec): + def add_database_scheduled(self, sched_sec: float) -> None: pass - def record_event_fetch(self, event_count): + def record_event_fetch(self, event_count: int) -> None: pass - def __bool__(self): + def __bool__(self) -> bool: return False @@ -379,7 +395,12 @@ class LoggingContext: ) return self - def __exit__(self, type, value, traceback) -> None: + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: """Restore the logging context in thread local storage to the state it was before this context was entered. Returns: @@ -399,10 +420,8 @@ class LoggingContext: # recorded against the correct metrics. self.finished = True - def copy_to(self, record) -> None: - """Copy logging fields from this context to a log record or - another LoggingContext - """ + def copy_to(self, record: "LoggingContext") -> None: + """Copy logging fields from this context to another LoggingContext""" # we track the current request record.request = self.request @@ -575,7 +594,7 @@ class LoggingContextFilter(logging.Filter): record. """ - def __init__(self, request: str = ""): + def __init__(self, request: str = "") -> None: self._default_request = request def filter(self, record: logging.LogRecord) -> Literal[True]: @@ -626,7 +645,12 @@ class PreserveLoggingContext: def __enter__(self) -> None: self._old_context = set_current_context(self._new_context) - def __exit__(self, type, value, traceback) -> None: + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: context = set_current_context(self._old_context) if context != self._new_context: @@ -711,16 +735,19 @@ def nested_logging_context(suffix: str) -> LoggingContext: ) -def preserve_fn(f): +def preserve_fn(f: F) -> F: """Function decorator which wraps the function with run_in_background""" - def g(*args, **kwargs): + @functools.wraps(f) + def g(*args: Any, **kwargs: Any) -> Any: return run_in_background(f, *args, **kwargs) - return g + return cast(F, g) -def run_in_background(f, *args, **kwargs) -> defer.Deferred: +def run_in_background( + f: Callable[..., T], *args: Any, **kwargs: Any +) -> "defer.Deferred[T]": """Calls a function, ensuring that the current context is restored after return from the function, and that the sentinel context is set once the deferred returned by the function completes. @@ -823,7 +850,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: return result -def defer_to_thread(reactor, f, *args, **kwargs): +def defer_to_thread( + reactor: IReactorThreads, f: Callable[..., T], *args: Any, **kwargs: Any +) -> "defer.Deferred[T]": """ Calls the function `f` using a thread from the reactor's default threadpool and returns the result as a Deferred. @@ -855,7 +884,13 @@ def defer_to_thread(reactor, f, *args, **kwargs): return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) -def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): +def defer_to_threadpool( + reactor: IReactorThreads, + threadpool: ThreadPool, + f: Callable[..., T], + *args: Any, + **kwargs: Any, +) -> "defer.Deferred[T]": """ A wrapper for twisted.internet.threads.deferToThreadpool, which handles logcontexts correctly. @@ -897,7 +932,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): assert isinstance(curr_context, LoggingContext) parent_context = curr_context - def g(): + def g() -> T: with LoggingContext(str(curr_context), parent_context=parent_context): return f(*args, **kwargs) |