diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index fb937b3f28..643492ceaf 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -32,8 +32,9 @@ from twisted.internet.endpoints import (
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
-from twisted.internet.interfaces import IPushProducer, ITransport
+from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
from twisted.internet.protocol import Factory, Protocol
+from twisted.internet.tcp import Connection
from twisted.python.failure import Failure
logger = logging.getLogger(__name__)
@@ -52,7 +53,9 @@ class LogProducer:
format: A callable to format the log record to a string.
"""
- transport = attr.ib(type=ITransport)
+ # This is essentially ITCPTransport, but that is missing certain fields
+ # (connected and registerProducer) which are part of the implementation.
+ transport = attr.ib(type=Connection)
_format = attr.ib(type=Callable[[logging.LogRecord], str])
_buffer = attr.ib(type=deque)
_paused = attr.ib(default=False, type=bool, init=False)
@@ -121,7 +124,9 @@ class RemoteHandler(logging.Handler):
try:
ip = ip_address(self.host)
if isinstance(ip, IPv4Address):
- endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
+ endpoint = TCP4ClientEndpoint(
+ _reactor, self.host, self.port
+ ) # type: IStreamClientEndpoint
elif isinstance(ip, IPv6Address):
endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
else:
@@ -147,8 +152,6 @@ class RemoteHandler(logging.Handler):
if self._connection_waiter:
return
- self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
-
def fail(failure: Failure) -> None:
# If the Deferred was cancelled (e.g. during shutdown) do not try to
# reconnect (this will cause an infinite loop of errors).
@@ -161,9 +164,13 @@ class RemoteHandler(logging.Handler):
self._connect()
def writer(result: Protocol) -> None:
+ # Force recognising transport as a Connection and not the more
+ # generic ITransport.
+ transport = result.transport # type: Connection # type: ignore
+
# We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing.
- if self._producer and result.transport is self._producer.transport:
+ if self._producer and transport is self._producer.transport:
self._producer.resumeProducing()
self._connection_waiter = None
return
@@ -174,13 +181,17 @@ class RemoteHandler(logging.Handler):
# Make a new producer and start it.
self._producer = LogProducer(
- buffer=self._buffer, transport=result.transport, format=self.format,
+ buffer=self._buffer,
+ transport=transport,
+ format=self.format,
)
- result.transport.registerProducer(self._producer, True)
+ transport.registerProducer(self._producer, True)
self._producer.resumeProducing()
self._connection_waiter = None
- self._connection_waiter.addCallbacks(writer, fail)
+ deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred
+ deferred.addCallbacks(writer, fail)
+ self._connection_waiter = deferred
def _handle_pressure(self) -> None:
"""
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 14d9c104c2..3e054f615c 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -60,7 +60,10 @@ def parse_drain_configs(
)
# Either use the default formatter or the tersejson one.
- if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,):
+ if logging_type in (
+ DrainType.CONSOLE_JSON,
+ DrainType.FILE_JSON,
+ ):
formatter = "json" # type: Optional[str]
elif logging_type in (
DrainType.CONSOLE_JSON_TERSE,
@@ -131,7 +134,9 @@ def parse_drain_configs(
)
-def setup_structured_logging(log_config: dict,) -> dict:
+def setup_structured_logging(
+ log_config: dict,
+) -> dict:
"""
Convert a legacy structured logging configuration (from Synapse < v1.23.0)
to one compatible with the new standard library handlers.
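The `_structured.py` hunks are formatting-only (call sites rewrapped by black). For context, here is a hedged example of the legacy (pre-v1.23.0) config shape that `setup_structured_logging` converts; the drain names and locations are made up for illustration.

```python
from synapse.logging._structured import setup_structured_logging

# A legacy structured-logging config of the shape parse_drain_configs()
# consumes (drain names/locations are illustrative).
legacy_config = {
    "structured": True,
    "drains": {
        "console": {"type": "console_json", "location": "stdout"},
        "file": {"type": "file_json", "location": "/var/log/synapse.log"},
    },
}

# Produces a standard-library logging dict-config; the *_JSON drain types
# select the "json" formatter per the branch reformatted above.
std_config = setup_structured_logging(legacy_config)
```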
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index ca0c774cc5..03cf3c2b8e 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -203,10 +203,6 @@ class _Sentinel:
def copy_to(self, record):
pass
- def copy_to_twisted_log_entry(self, record):
- record["request"] = None
- record["scope"] = None
-
def start(self, rusage: "Optional[resource._RUsage]"):
pass
@@ -256,7 +252,12 @@ class LoggingContext:
"scope",
]
- def __init__(self, name=None, parent_context=None, request=None) -> None:
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ parent_context: "Optional[LoggingContext]" = None,
+ request: Optional[str] = None,
+ ) -> None:
self.previous_context = current_context()
self.name = name
@@ -337,7 +338,10 @@ class LoggingContext:
if self.previous_context != old_context:
logcontext_error(
"Expected previous context %r, found %r"
- % (self.previous_context, old_context,)
+ % (
+ self.previous_context,
+ old_context,
+ )
)
return self
@@ -372,13 +376,6 @@ class LoggingContext:
# we also track the current scope:
record.scope = self.scope
- def copy_to_twisted_log_entry(self, record) -> None:
- """
- Copy logging fields from this context to a Twisted log record.
- """
- record["request"] = self.request
- record["scope"] = self.scope
-
def start(self, rusage: "Optional[resource._RUsage]") -> None:
"""
Record that this logcontext is currently running.
@@ -542,28 +539,25 @@ class LoggingContext:
class LoggingContextFilter(logging.Filter):
"""Logging filter that adds values from the current logging context to each
record.
- Args:
- **defaults: Default values to avoid formatters complaining about
- missing fields
"""
- def __init__(self, **defaults) -> None:
- self.defaults = defaults
+    def __init__(self, request: str = "") -> None:
+ self._default_request = request
- def filter(self, record) -> Literal[True]:
+ def filter(self, record: logging.LogRecord) -> Literal[True]:
"""Add each fields from the logging contexts to the record.
Returns:
True to include the record in the log output.
"""
context = current_context()
- for key, value in self.defaults.items():
- setattr(record, key, value)
+ record.request = self._default_request # type: ignore
# context should never be None, but if it somehow ends up being, then
# we end up in a death spiral of infinite loops, so let's check, for
# robustness' sake.
if context is not None:
- context.copy_to(record)
+ # Logging is interested in the request.
+ record.request = context.request # type: ignore
return True
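With the filter now pinned to a single `request` field, wiring it up looks like this hedged sketch (handler and format string are illustrative):

```python
import logging

from synapse.logging.context import LoggingContextFilter

handler = logging.StreamHandler()
# Every record gains a `request` attribute: the current context's request,
# or the default ("" here) when no logging context is active.
handler.addFilter(LoggingContextFilter(request=""))
handler.setFormatter(
    logging.Formatter("%(asctime)s - %(request)s - %(message)s")
)
logging.getLogger("synapse").addHandler(handler)
```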
@@ -571,7 +565,7 @@ class LoggingContextFilter(logging.Filter):
class PreserveLoggingContext:
"""Context manager which replaces the logging context
- The previous logging context is restored on exit."""
+ The previous logging context is restored on exit."""
__slots__ = ["_old_context", "_new_context"]
@@ -594,7 +588,10 @@ class PreserveLoggingContext:
else:
logcontext_error(
"Expected logging context %s but found %s"
- % (self._new_context, context,)
+ % (
+ self._new_context,
+ context,
+ )
)
@@ -630,9 +627,7 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe
return current
-def nested_logging_context(
- suffix: str, parent_context: Optional[LoggingContext] = None
-) -> LoggingContext:
+def nested_logging_context(suffix: str) -> LoggingContext:
"""Creates a new logging context as a child of another.
The nested logging context will have a 'request' made up of the parent context's
@@ -646,20 +641,23 @@ def nested_logging_context(
# ... do stuff
Args:
- suffix (str): suffix to add to the parent context's 'request'.
- parent_context (LoggingContext|None): parent context. Will use the current context
- if None.
+ suffix: suffix to add to the parent context's 'request'.
Returns:
LoggingContext: new logging context.
"""
- if parent_context is not None:
- context = parent_context # type: LoggingContextOrSentinel
+ curr_context = current_context()
+ if not curr_context:
+ logger.warning(
+ "Starting nested logging context from sentinel context: metrics will be lost"
+ )
+ parent_context = None
+ prefix = ""
else:
- context = current_context()
- return LoggingContext(
- parent_context=context, request=str(context.request) + "-" + suffix
- )
+ assert isinstance(curr_context, LoggingContext)
+ parent_context = curr_context
+ prefix = str(parent_context.request)
+ return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix)
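A hedged sketch of the new call shape, where the parent is always taken from the current context (names are illustrative):

```python
from synapse.logging.context import LoggingContext, nested_logging_context

with LoggingContext(request="purge"):
    # Inherits from the current context; ctx.request is "purge-pass-1".
    with nested_logging_context(suffix="pass-1") as ctx:
        ...  # do stuff
```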
def preserve_fn(f):
@@ -671,7 +669,7 @@ def preserve_fn(f):
return g
-def run_in_background(f, *args, **kwargs):
+def run_in_background(f, *args, **kwargs) -> defer.Deferred:
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
deferred returned by the function completes.
@@ -691,7 +689,7 @@ def run_in_background(f, *args, **kwargs):
current = current_context()
try:
res = f(*args, **kwargs)
- except: # noqa: E722
+ except Exception:
# the assumption here is that the caller doesn't want to be disturbed
# by synchronous exceptions, so let's turn them into Failures.
return defer.fail()
@@ -699,8 +697,10 @@ def run_in_background(f, *args, **kwargs):
if isinstance(res, types.CoroutineType):
res = defer.ensureDeferred(res)
+    # At this point we should have a Deferred; if not, then f was a
+    # synchronous function, so wrap its result in a Deferred for consistency.
if not isinstance(res, defer.Deferred):
- return res
+ return defer.succeed(res)
if res.called and not res.paused:
# The function should have maintained the logcontext, so we can
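After this change `run_in_background` always hands back a Deferred, even when `f` returned synchronously. A hedged illustration (the function name is made up):

```python
import logging

from twisted.internet import defer

from synapse.logging.context import run_in_background

logger = logging.getLogger(__name__)


def sync_fn() -> int:
    return 42


d = run_in_background(sync_fn)  # previously this returned the bare 42
assert isinstance(d, defer.Deferred)
d.addCallback(lambda value: logger.info("got %d", value))
```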
@@ -836,10 +836,18 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
Deferred: A Deferred which fires a callback with the result of `f`, or an
errback if `f` throws an exception.
"""
- logcontext = current_context()
+ curr_context = current_context()
+ if not curr_context:
+ logger.warning(
+ "Calling defer_to_threadpool from sentinel context: metrics will be lost"
+ )
+ parent_context = None
+ else:
+ assert isinstance(curr_context, LoggingContext)
+ parent_context = curr_context
def g():
- with LoggingContext(parent_context=logcontext):
+ with LoggingContext(parent_context=parent_context):
return f(*args, **kwargs)
return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
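A hedged sketch of calling `defer_to_threadpool` from inside a real logging context, so the threaded work is attributed to it; the blocking function is illustrative:

```python
import logging
import time

from twisted.internet import reactor

from synapse.logging.context import LoggingContext, defer_to_threadpool

logger = logging.getLogger(__name__)


def blocking_fn() -> str:
    time.sleep(1)  # stand-in for blocking I/O
    return "done"


with LoggingContext(request="db-migrate"):
    d = defer_to_threadpool(reactor, reactor.getThreadPool(), blocking_fn)
    d.addCallback(lambda result: logger.info("threadpool says %s", result))
```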
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index ab586c318c..aa146e8bb8 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -169,7 +169,7 @@ import inspect
import logging
import re
from functools import wraps
-from typing import TYPE_CHECKING, Dict, Optional, Type
+from typing import TYPE_CHECKING, Dict, Optional, Pattern, Type
import attr
@@ -238,8 +238,7 @@ try:
@attr.s(slots=True, frozen=True)
class _WrappedRustReporter:
- """Wrap the reporter to ensure `report_span` never throws.
- """
+ """Wrap the reporter to ensure `report_span` never throws."""
_reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
@@ -263,7 +262,7 @@ logger = logging.getLogger(__name__)
# Block everything by default
# A regex which matches the server_names to expose traces for.
# None means 'block everything'.
-_homeserver_whitelist = None
+_homeserver_whitelist = None # type: Optional[Pattern[str]]
# Util methods
@@ -326,8 +325,7 @@ def noop_context_manager(*args, **kwargs):
def init_tracer(hs: "HomeServer"):
- """Set the whitelists and initialise the JaegerClient tracer
- """
+ """Set the whitelists and initialise the JaegerClient tracer"""
global opentracing
if not hs.config.opentracer_enabled:
# We don't have a tracer
@@ -384,7 +382,7 @@ def whitelisted_homeserver(destination):
Args:
destination (str)
- """
+ """
if _homeserver_whitelist:
return _homeserver_whitelist.match(destination)
@@ -791,7 +789,7 @@ def tag_args(func):
@wraps(func)
def _tag_args_inner(*args, **kwargs):
- argspec = inspect.getargspec(func)
+ argspec = inspect.getfullargspec(func)
for i, arg in enumerate(argspec.args[1:]):
set_tag("ARG_" + arg, args[i])
set_tag("args", args[len(argspec.args) :])
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index becf66dd86..fd3543ab04 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -43,8 +43,7 @@ def _log_debug_as_f(f, msg, msg_args):
def log_function(f):
- """ Function decorator that logs every call to that function.
- """
+ """Function decorator that logs every call to that function."""
func_name = f.__name__
@wraps(f)
|