diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 02e5ddd2ef..fc3e514ec8 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -22,21 +22,37 @@ them.
See doc/log_contexts.rst for details on how this works.
"""
+import functools
import inspect
import logging
import threading
import typing
import warnings
-from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
+from types import TracebackType
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
import attr
from typing_extensions import Literal
from twisted.internet import defer, threads
+from twisted.internet.interfaces import IReactorThreads, ThreadPool
if TYPE_CHECKING:
from synapse.logging.scopecontextmanager import _LogContextScope
+T = TypeVar("T")
+F = TypeVar("F", bound=Callable[..., Any])
+
logger = logging.getLogger(__name__)
try:
@@ -66,7 +82,7 @@ except Exception:
# a hook which can be set during testing to assert that we aren't abusing logcontexts.
-def logcontext_error(msg: str):
+def logcontext_error(msg: str) -> None:
logger.warning(msg)
@@ -220,28 +236,28 @@ class _Sentinel:
self.scope = None
self.tag = None
- def __str__(self):
+ def __str__(self) -> str:
return "sentinel"
- def copy_to(self, record):
+ def copy_to(self, record: "LoggingContext") -> None:
pass
- def start(self, rusage: "Optional[resource._RUsage]"):
+ def start(self, rusage: "Optional[resource._RUsage]") -> None:
pass
- def stop(self, rusage: "Optional[resource._RUsage]"):
+ def stop(self, rusage: "Optional[resource._RUsage]") -> None:
pass
- def add_database_transaction(self, duration_sec):
+ def add_database_transaction(self, duration_sec: float) -> None:
pass
- def add_database_scheduled(self, sched_sec):
+ def add_database_scheduled(self, sched_sec: float) -> None:
pass
- def record_event_fetch(self, event_count):
+ def record_event_fetch(self, event_count: int) -> None:
pass
- def __bool__(self):
+ def __bool__(self) -> bool:
return False
@@ -379,7 +395,12 @@ class LoggingContext:
)
return self
- def __exit__(self, type, value, traceback) -> None:
+ def __exit__(
+ self,
+ type: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None:
"""Restore the logging context in thread local storage to the state it
was before this context was entered.
Returns:
@@ -399,10 +420,8 @@ class LoggingContext:
# recorded against the correct metrics.
self.finished = True
- def copy_to(self, record) -> None:
- """Copy logging fields from this context to a log record or
- another LoggingContext
- """
+ def copy_to(self, record: "LoggingContext") -> None:
+ """Copy logging fields from this context to another LoggingContext"""
# we track the current request
record.request = self.request
@@ -575,7 +594,7 @@ class LoggingContextFilter(logging.Filter):
record.
"""
- def __init__(self, request: str = ""):
+ def __init__(self, request: str = "") -> None:
self._default_request = request
def filter(self, record: logging.LogRecord) -> Literal[True]:
@@ -626,7 +645,12 @@ class PreserveLoggingContext:
def __enter__(self) -> None:
self._old_context = set_current_context(self._new_context)
- def __exit__(self, type, value, traceback) -> None:
+ def __exit__(
+ self,
+ type: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None:
context = set_current_context(self._old_context)
if context != self._new_context:
@@ -711,16 +735,19 @@ def nested_logging_context(suffix: str) -> LoggingContext:
)
-def preserve_fn(f):
+def preserve_fn(f: F) -> F:
"""Function decorator which wraps the function with run_in_background"""
- def g(*args, **kwargs):
+ @functools.wraps(f)
+ def g(*args: Any, **kwargs: Any) -> Any:
return run_in_background(f, *args, **kwargs)
- return g
+ return cast(F, g)
-def run_in_background(f, *args, **kwargs) -> defer.Deferred:
+def run_in_background(
+ f: Callable[..., T], *args: Any, **kwargs: Any
+) -> "defer.Deferred[T]":
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
deferred returned by the function completes.
@@ -823,7 +850,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
return result
-def defer_to_thread(reactor, f, *args, **kwargs):
+def defer_to_thread(
+ reactor: IReactorThreads, f: Callable[..., T], *args: Any, **kwargs: Any
+) -> "defer.Deferred[T]":
"""
Calls the function `f` using a thread from the reactor's default threadpool and
returns the result as a Deferred.
@@ -855,7 +884,13 @@ def defer_to_thread(reactor, f, *args, **kwargs):
return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
-def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
+def defer_to_threadpool(
+ reactor: IReactorThreads,
+ threadpool: ThreadPool,
+ f: Callable[..., T],
+ *args: Any,
+ **kwargs: Any,
+) -> "defer.Deferred[T]":
"""
A wrapper for twisted.internet.threads.deferToThreadpool, which handles
logcontexts correctly.
@@ -897,7 +932,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
- def g():
+ def g() -> T:
with LoggingContext(str(curr_context), parent_context=parent_context):
return f(*args, **kwargs)
|