diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 3220e985a9..7372450b45 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -42,7 +42,7 @@ from synapse.logging._terse_json import (
TerseJSONToConsoleLogObserver,
TerseJSONToTCPLogObserver,
)
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context
def stdlib_log_level_to_twisted(level: str) -> LogLevel:
@@ -86,7 +86,7 @@ class LogContextObserver(object):
].startswith("Timing out client"):
return
- context = LoggingContext.current_context()
+ context = current_context()
# Copy the context information to the log event.
if context is not None:
@@ -185,7 +185,7 @@ DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
def parse_drain_configs(
- drains: dict
+ drains: dict,
) -> typing.Generator[DrainConfiguration, None, None]:
"""
Parse the drain configurations.
@@ -261,6 +261,18 @@ def parse_drain_configs(
)
+class StoppableLogPublisher(LogPublisher):
+ """
+ A log publisher that can tell its observers to shut down any external
+ communications.
+ """
+
+ def stop(self):
+ for obs in self._observers:
+ if hasattr(obs, "stop"):
+ obs.stop()
+
+
def setup_structured_logging(
hs,
config,
@@ -336,7 +348,7 @@ def setup_structured_logging(
# We should never get here, but, just in case, throw an error.
raise ConfigError("%s drain type cannot be configured" % (observer.type,))
- publisher = LogPublisher(*observers)
+ publisher = StoppableLogPublisher(*observers)
log_filter = LogLevelFilterPredicate()
for namespace, namespace_config in log_config.get(
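
The new StoppableLogPublisher relies purely on duck typing: any observer that exposes a stop() method is told to shut down, and plain callables without one are skipped. A minimal sketch of that contract (the observer names here are illustrative, not from the patch):

    from synapse.logging._structured import StoppableLogPublisher

    class FakeTCPObserver:
        # an observer holding an external connection to release
        def __call__(self, event):
            pass  # would forward the event over the network

        def stop(self):
            print("closing connection")

    def console_observer(event):
        pass  # no stop() method, so publisher.stop() skips it

    publisher = StoppableLogPublisher(FakeTCPObserver(), console_observer)
    publisher.stop()  # prints "closing connection"; console_observer untouched
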
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 0ebbde06f2..c0b9384189 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -17,25 +17,29 @@
Log formatters that output terse JSON.
"""
+import json
import sys
+import traceback
from collections import deque
from ipaddress import IPv4Address, IPv6Address, ip_address
from math import floor
-from typing import IO
+from typing import IO, Optional
import attr
-from simplejson import dumps
from zope.interface import implementer
from twisted.application.internet import ClientService
+from twisted.internet.defer import Deferred
from twisted.internet.endpoints import (
HostnameEndpoint,
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
+from twisted.internet.interfaces import IPushProducer, ITransport
from twisted.internet.protocol import Factory, Protocol
from twisted.logger import FileLogObserver, ILogObserver, Logger
-from twisted.python.failure import Failure
+
+_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
def flatten_event(event: dict, metadata: dict, include_time: bool = False):
@@ -141,19 +145,57 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb
def formatEvent(_event: dict) -> str:
flattened = flatten_event(_event, metadata)
- return dumps(flattened, ensure_ascii=False, separators=(",", ":")) + "\n"
+ return _encoder.encode(flattened) + "\n"
return FileLogObserver(outFile, formatEvent)
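
Swapping simplejson.dumps for a single module-level json.JSONEncoder avoids re-processing the keyword arguments on every log line while producing the same terse output: ensure_ascii=False keeps non-ASCII text readable, and the separators drop the default padding. For example:

    import json

    _encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))

    event = {"log": "café", "level": "info"}
    print(_encoder.encode(event))  # {"log":"café","level":"info"}
    print(json.dumps(event))       # {"log": "caf\u00e9", "level": "info"}
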
@attr.s
+@implementer(IPushProducer)
+class LogProducer(object):
+ """
+ An IPushProducer that writes logs from its buffer to its transport when it
+ is resumed.
+
+ Args:
+ buffer: Log buffer to read logs from.
+ transport: Transport to write to.
+ """
+
+ transport = attr.ib(type=ITransport)
+ _buffer = attr.ib(type=deque)
+ _paused = attr.ib(default=False, type=bool, init=False)
+
+ def pauseProducing(self):
+ self._paused = True
+
+ def stopProducing(self):
+ self._paused = True
+ self._buffer = deque()
+
+ def resumeProducing(self):
+ self._paused = False
+
+ while not self._paused and self._buffer and self.transport.connected:
+ try:
+ event = self._buffer.popleft()
+ self.transport.write(_encoder.encode(event).encode("utf8"))
+ self.transport.write(b"\n")
+ except Exception:
+ # Something has gone wrong writing to the transport -- log it
+ # and break out of the while.
+ traceback.print_exc(file=sys.__stderr__)
+ break
+
+
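
LogProducer implements Twisted's push-producer protocol: once registered, the transport calls pauseProducing() when its outgoing buffer fills and resumeProducing() when it drains, so the deque is only emptied as fast as the peer can accept data. A toy transport showing the handshake (a rough sketch; ToyTransport is not part of the patch):

    from collections import deque

    class ToyTransport:
        # pauses its producer after every write, mimicking backpressure
        connected = True

        def __init__(self):
            self.producer = None

        def write(self, data):
            print("wrote", data)
            self.producer.pauseProducing()

    transport = ToyTransport()
    producer = LogProducer(transport=transport, buffer=deque([{"n": 1}, {"n": 2}]))
    transport.producer = producer

    producer.resumeProducing()  # writes b'{"n":1}' and b'\n', then stays paused
    producer.resumeProducing()  # writes the second event
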
+@attr.s
@implementer(ILogObserver)
class TerseJSONToTCPLogObserver(object):
"""
An IObserver that writes JSON logs to a TCP target.
Args:
- hs (HomeServer): The Homeserver that is being logged for.
+ hs (HomeServer): The homeserver that is being logged for.
host: The host of the logging target.
port: The logging target's port.
metadata: Metadata to be added to each log entry.
@@ -165,8 +207,9 @@ class TerseJSONToTCPLogObserver(object):
metadata = attr.ib(type=dict)
maximum_buffer = attr.ib(type=int)
_buffer = attr.ib(default=attr.Factory(deque), type=deque)
- _writer = attr.ib(default=None)
+ _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
_logger = attr.ib(default=attr.Factory(Logger))
+ _producer = attr.ib(default=None, type=Optional[LogProducer])
def start(self) -> None:
@@ -187,38 +230,44 @@ class TerseJSONToTCPLogObserver(object):
factory = Factory.forProtocol(Protocol)
self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
self._service.startService()
+ self._connect()
- def _write_loop(self) -> None:
+ def stop(self):
+ self._service.stopService()
+
+ def _connect(self) -> None:
"""
- Implement the write loop.
+ Triggers an attempt to connect, and then to write to the remote, unless a connection attempt is already in progress.
"""
- if self._writer:
+ if self._connection_waiter:
return
- self._writer = self._service.whenConnected()
+ self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
+
+ @self._connection_waiter.addErrback
+ def fail(r):
+ r.printTraceback(file=sys.__stderr__)
+ self._connection_waiter = None
+ self._connect()
- @self._writer.addBoth
+ @self._connection_waiter.addCallback
def writer(r):
- if isinstance(r, Failure):
- r.printTraceback(file=sys.__stderr__)
- self._writer = None
- self.hs.get_reactor().callLater(1, self._write_loop)
+ # We have a connection. If we already have a producer, and its
+ # transport is the same, just trigger a resumeProducing.
+ if self._producer and r.transport is self._producer.transport:
+ self._producer.resumeProducing()
+ self._connection_waiter = None
return
- try:
- for event in self._buffer:
- r.transport.write(
- dumps(event, ensure_ascii=False, separators=(",", ":")).encode(
- "utf8"
- )
- )
- r.transport.write(b"\n")
- self._buffer.clear()
- except Exception as e:
- sys.__stderr__.write("Failed writing out logs with %s\n" % (str(e),))
-
- self._writer = False
- self.hs.get_reactor().callLater(1, self._write_loop)
+ # If the producer is still producing, stop it.
+ if self._producer:
+ self._producer.stopProducing()
+
+ # Make a new producer and start it.
+ self._producer = LogProducer(buffer=self._buffer, transport=r.transport)
+ r.transport.registerProducer(self._producer, True)
+ self._producer.resumeProducing()
+ self._connection_waiter = None
def _handle_pressure(self) -> None:
"""
@@ -277,4 +326,4 @@ class TerseJSONToTCPLogObserver(object):
self._logger.failure("Failed clearing backpressure")
# Try and write immediately.
- self._write_loop()
+ self._connect()
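
The rewritten connection logic leans on ClientService, which reconnects with backoff on its own; whenConnected(failAfterFailures=1) hands back a Deferred that errbacks after a single failed attempt instead of waiting indefinitely, which is what lets _connect() log the failure and re-arm itself. The idiom in isolation (a sketch; the host and port are placeholders):

    from twisted.application.internet import ClientService
    from twisted.internet import reactor
    from twisted.internet.endpoints import TCP4ClientEndpoint
    from twisted.internet.protocol import Factory, Protocol

    endpoint = TCP4ClientEndpoint(reactor, "127.0.0.1", 9020)
    service = ClientService(endpoint, Factory.forProtocol(Protocol), clock=reactor)
    service.startService()

    d = service.whenConnected(failAfterFailures=1)

    def connected(protocol):
        # equivalent to the writer() callback: we now have a transport
        protocol.transport.write(b'{"hello":"world"}\n')

    def failed(failure):
        # equivalent to fail(): log the failure and ask for the next attempt
        print("connection attempt failed; ClientService keeps retrying")

    d.addCallbacks(connected, failed)
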
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 370000e377..8b9c4e38bd 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -23,13 +23,20 @@ them.
See doc/log_contexts.rst for details on how this works.
"""
+import inspect
import logging
import threading
import types
-from typing import Any, List
+import warnings
+from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
+
+from typing_extensions import Literal
from twisted.internet import defer, threads
+if TYPE_CHECKING:
+ from synapse.logging.scopecontextmanager import _LogContextScope
+
logger = logging.getLogger(__name__)
try:
@@ -45,7 +52,7 @@ try:
is_thread_resource_usage_supported = True
- def get_thread_resource_usage():
+ def get_thread_resource_usage() -> "Optional[resource._RUsage]":
return resource.getrusage(RUSAGE_THREAD)
@@ -54,7 +61,7 @@ except Exception:
# won't track resource usage.
is_thread_resource_usage_supported = False
- def get_thread_resource_usage():
+ def get_thread_resource_usage() -> "Optional[resource._RUsage]":
return None
@@ -90,7 +97,7 @@ class ContextResourceUsage(object):
"evt_db_fetch_count",
]
- def __init__(self, copy_from=None):
+ def __init__(self, copy_from: "Optional[ContextResourceUsage]" = None) -> None:
"""Create a new ContextResourceUsage
Args:
@@ -100,27 +107,28 @@ class ContextResourceUsage(object):
if copy_from is None:
self.reset()
else:
- self.ru_utime = copy_from.ru_utime
- self.ru_stime = copy_from.ru_stime
- self.db_txn_count = copy_from.db_txn_count
+ # FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now
+ self.ru_utime = copy_from.ru_utime # type: float
+ self.ru_stime = copy_from.ru_stime # type: float
+ self.db_txn_count = copy_from.db_txn_count # type: int
- self.db_txn_duration_sec = copy_from.db_txn_duration_sec
- self.db_sched_duration_sec = copy_from.db_sched_duration_sec
- self.evt_db_fetch_count = copy_from.evt_db_fetch_count
+ self.db_txn_duration_sec = copy_from.db_txn_duration_sec # type: float
+ self.db_sched_duration_sec = copy_from.db_sched_duration_sec # type: float
+ self.evt_db_fetch_count = copy_from.evt_db_fetch_count # type: int
- def copy(self):
+ def copy(self) -> "ContextResourceUsage":
return ContextResourceUsage(copy_from=self)
- def reset(self):
+ def reset(self) -> None:
self.ru_stime = 0.0
self.ru_utime = 0.0
self.db_txn_count = 0
- self.db_txn_duration_sec = 0
- self.db_sched_duration_sec = 0
+ self.db_txn_duration_sec = 0.0
+ self.db_sched_duration_sec = 0.0
self.evt_db_fetch_count = 0
- def __repr__(self):
+ def __repr__(self) -> str:
return (
"<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
"db_txn_count='%r', db_txn_duration_sec='%r', "
@@ -134,7 +142,7 @@ class ContextResourceUsage(object):
self.evt_db_fetch_count,
)
- def __iadd__(self, other):
+ def __iadd__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
"""Add another ContextResourceUsage's stats to this one's.
Args:
@@ -148,7 +156,7 @@ class ContextResourceUsage(object):
self.evt_db_fetch_count += other.evt_db_fetch_count
return self
- def __isub__(self, other):
+ def __isub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
self.ru_utime -= other.ru_utime
self.ru_stime -= other.ru_stime
self.db_txn_count -= other.db_txn_count
@@ -157,17 +165,67 @@ class ContextResourceUsage(object):
self.evt_db_fetch_count -= other.evt_db_fetch_count
return self
- def __add__(self, other):
+ def __add__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
res = ContextResourceUsage(copy_from=self)
res += other
return res
- def __sub__(self, other):
+ def __sub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
res = ContextResourceUsage(copy_from=self)
res -= other
return res
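
With the operators above fully typed, ContextResourceUsage behaves like a small value type: snapshots can be subtracted to get the cost of an interval and added to aggregate across contexts. A quick illustration:

    before = ContextResourceUsage()
    before.ru_utime = 0.5
    before.db_txn_count = 2

    after = before.copy()
    after.ru_utime = 0.9
    after.db_txn_count = 5

    delta = after - before           # cost of the interval
    print(delta.db_txn_count)        # 3
    print(round(delta.ru_utime, 1))  # 0.4

    total = before + delta           # recombine
    print(total.db_txn_count)        # 5
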
+LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
+
+
+class _Sentinel(object):
+ """Sentinel to represent the root context"""
+
+ __slots__ = ["previous_context", "finished", "request", "scope", "tag"]
+
+ def __init__(self) -> None:
+ # Minimal set for compatibility with LoggingContext
+ self.previous_context = None
+ self.finished = False
+ self.request = None
+ self.scope = None
+ self.tag = None
+
+ def __str__(self):
+ return "sentinel"
+
+ def copy_to(self, record):
+ pass
+
+ def copy_to_twisted_log_entry(self, record):
+ record["request"] = None
+ record["scope"] = None
+
+ def start(self, rusage: "Optional[resource._RUsage]"):
+ pass
+
+ def stop(self, rusage: "Optional[resource._RUsage]"):
+ pass
+
+ def add_database_transaction(self, duration_sec):
+ pass
+
+ def add_database_scheduled(self, sched_sec):
+ pass
+
+ def record_event_fetch(self, event_count):
+ pass
+
+ def __nonzero__(self):
+ return False
+
+ __bool__ = __nonzero__ # python3
+
+
+SENTINEL_CONTEXT = _Sentinel()
+
+
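
Because _Sentinel defines __bool__ (and __nonzero__ for Python 2) to return False, call sites can use a plain truthiness check rather than comparing against the singleton, which is exactly what the scopecontextmanager change below switches to:

    ctx = current_context()
    if not ctx:
        # we are in the root (sentinel) context
        print("no logging context active")
    else:
        print("active request:", ctx.request)
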
class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a
"with" block.
@@ -189,67 +247,32 @@ class LoggingContext(object):
"_resource_usage",
"usage_start",
"main_thread",
- "alive",
+ "finished",
"request",
"tag",
"scope",
]
- thread_local = threading.local()
-
- class Sentinel(object):
- """Sentinel to represent the root context"""
-
- __slots__ = [] # type: List[Any]
-
- def __str__(self):
- return "sentinel"
-
- def copy_to(self, record):
- pass
-
- def copy_to_twisted_log_entry(self, record):
- record["request"] = None
- record["scope"] = None
-
- def start(self):
- pass
-
- def stop(self):
- pass
-
- def add_database_transaction(self, duration_sec):
- pass
-
- def add_database_scheduled(self, sched_sec):
- pass
-
- def record_event_fetch(self, event_count):
- pass
-
- def __nonzero__(self):
- return False
-
- __bool__ = __nonzero__ # python3
-
- sentinel = Sentinel()
-
- def __init__(self, name=None, parent_context=None, request=None):
- self.previous_context = LoggingContext.current_context()
+ def __init__(self, name=None, parent_context=None, request=None) -> None:
+ self.previous_context = current_context()
self.name = name
# track the resources used by this context so far
self._resource_usage = ContextResourceUsage()
- # If alive has the thread resource usage when the logcontext last
- # became active.
- self.usage_start = None
+ # The thread resource usage when the logcontext became active. None
+ # if the context is not currently active.
+ self.usage_start = None # type: Optional[resource._RUsage]
self.main_thread = get_thread_id()
self.request = None
self.tag = ""
- self.alive = True
- self.scope = None
+ self.scope = None # type: Optional[_LogContextScope]
+
+ # keep track of whether we have hit the __exit__ block for this context
+ # (suggesting that the thing that created the context thinks it should
+ # be finished, and that re-activating it would suggest an error).
+ self.finished = False
self.parent_context = parent_context
@@ -260,76 +283,83 @@ class LoggingContext(object):
# the request param overrides the request from the parent context
self.request = request
- def __str__(self):
+ def __str__(self) -> str:
if self.request:
return str(self.request)
return "%s@%x" % (self.name, id(self))
@classmethod
- def current_context(cls):
+ def current_context(cls) -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage
+ This exists for backwards compatibility. ``current_context()`` should be
+ called directly.
+
Returns:
LoggingContext: the current logging context
"""
- return getattr(cls.thread_local, "current_context", cls.sentinel)
+ warnings.warn(
+ "synapse.logging.context.LoggingContext.current_context() is deprecated "
+ "in favor of synapse.logging.context.current_context().",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return current_context()
@classmethod
- def set_current_context(cls, context):
+ def set_current_context(
+ cls, context: LoggingContextOrSentinel
+ ) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
+
+ This exists for backwards compatibility. ``set_current_context()`` should be
+ called directly.
+
Args:
context(LoggingContext): The context to activate.
Returns:
The context that was previously active
"""
- current = cls.current_context()
-
- if current is not context:
- current.stop()
- cls.thread_local.current_context = context
- context.start()
- return current
+ warnings.warn(
+ "synapse.logging.context.LoggingContext.set_current_context() is deprecated "
+ "in favor of synapse.logging.context.set_current_context().",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return set_current_context(context)
- def __enter__(self):
+ def __enter__(self) -> "LoggingContext":
"""Enters this logging context into thread local storage"""
- old_context = self.set_current_context(self)
+ old_context = set_current_context(self)
if self.previous_context != old_context:
- logger.warn(
+ logger.warning(
"Expected previous context %r, found %r",
self.previous_context,
old_context,
)
- self.alive = True
-
return self
- def __exit__(self, type, value, traceback):
+ def __exit__(self, type, value, traceback) -> None:
"""Restore the logging context in thread local storage to the state it
was before this context was entered.
Returns:
None to avoid suppressing any exceptions that were thrown.
"""
- current = self.set_current_context(self.previous_context)
+ current = set_current_context(self.previous_context)
if current is not self:
- if current is self.sentinel:
+ if current is SENTINEL_CONTEXT:
logger.warning("Expected logging context %s was lost", self)
else:
logger.warning(
"Expected logging context %s but found %s", self, current
)
- self.previous_context = None
- self.alive = False
-
- # if we have a parent, pass our CPU usage stats on
- if self.parent_context is not None and hasattr(
- self.parent_context, "_resource_usage"
- ):
- self.parent_context._resource_usage += self._resource_usage
- # reset them in case we get entered again
- self._resource_usage.reset()
+ # the fact that we are here suggests that the caller thinks that everything
+ # is done and dusted for this logcontext, and further activity will not get
+ # recorded against the correct metrics.
+ self.finished = True
- def copy_to(self, record):
+ def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or
another LoggingContext
"""
@@ -340,44 +370,72 @@ class LoggingContext(object):
# we also track the current scope:
record.scope = self.scope
- def copy_to_twisted_log_entry(self, record):
+ def copy_to_twisted_log_entry(self, record) -> None:
"""
Copy logging fields from this context to a Twisted log record.
"""
record["request"] = self.request
record["scope"] = self.scope
- def start(self):
+ def start(self, rusage: "Optional[resource._RUsage]") -> None:
+ """
+ Record that this logcontext is currently running.
+
+ This should not be called directly: use set_current_context
+
+ Args:
+ rusage: the resources used by the current thread, at the point of
+ switching to this logcontext. May be None if this platform doesn't
+ support getrusage.
+ """
if get_thread_id() != self.main_thread:
logger.warning("Started logcontext %s on different thread", self)
return
+ if self.finished:
+ logger.warning("Re-starting finished log context %s", self)
+
# If we haven't already started record the thread resource usage so
# far
- if not self.usage_start:
- self.usage_start = get_thread_resource_usage()
+ if self.usage_start:
+ logger.warning("Re-starting already-active log context %s", self)
+ else:
+ self.usage_start = rusage
- def stop(self):
- if get_thread_id() != self.main_thread:
- logger.warning("Stopped logcontext %s on different thread", self)
- return
+ def stop(self, rusage: "Optional[resource._RUsage]") -> None:
+ """
+ Record that this logcontext is no longer running.
- # When we stop, let's record the cpu used since we started
- if not self.usage_start:
- # Log a warning on platforms that support thread usage tracking
- if is_thread_resource_usage_supported:
+ This should not be called directly: use set_current_context
+
+ Args:
+ rusage: the resources used by the current thread, at the point of
+ switching away from this logcontext. May be None if this platform
+ doesn't support getrusage.
+ """
+
+ try:
+ if get_thread_id() != self.main_thread:
+ logger.warning("Stopped logcontext %s on different thread", self)
+ return
+
+ if not rusage:
+ return
+
+ # Record the cpu used since we started
+ if not self.usage_start:
logger.warning(
- "Called stop on logcontext %s without calling start", self
+ "Called stop on logcontext %s without recording a start rusage",
+ self,
)
- return
+ return
- utime_delta, stime_delta = self._get_cputime()
- self._resource_usage.ru_utime += utime_delta
- self._resource_usage.ru_stime += stime_delta
-
- self.usage_start = None
+ utime_delta, stime_delta = self._get_cputime(rusage)
+ self.add_cputime(utime_delta, stime_delta)
+ finally:
+ self.usage_start = None
- def get_resource_usage(self):
+ def get_resource_usage(self) -> ContextResourceUsage:
"""Get resources used by this logcontext so far.
Returns:
@@ -390,19 +448,24 @@ class LoggingContext(object):
# If we are on the correct thread and we're currently running then we
# can include resource usage so far.
is_main_thread = get_thread_id() == self.main_thread
- if self.alive and self.usage_start and is_main_thread:
- utime_delta, stime_delta = self._get_cputime()
+ if self.usage_start and is_main_thread:
+ rusage = get_thread_resource_usage()
+ assert rusage is not None
+ utime_delta, stime_delta = self._get_cputime(rusage)
res.ru_utime += utime_delta
res.ru_stime += stime_delta
return res
- def _get_cputime(self):
- """Get the cpu usage time so far
+ def _get_cputime(self, current: "resource._RUsage") -> Tuple[float, float]:
+ """Get the cpu usage time between start() and the given rusage
+
+ Args:
+ current: the current resource usage
Returns: Tuple[float, float]: seconds in user mode, seconds in system mode
"""
- current = get_thread_resource_usage()
+ assert self.usage_start is not None
utime_delta = current.ru_utime - self.usage_start.ru_utime
stime_delta = current.ru_stime - self.usage_start.ru_stime
@@ -426,30 +489,52 @@ class LoggingContext(object):
return utime_delta, stime_delta
- def add_database_transaction(self, duration_sec):
+ def add_cputime(self, utime_delta: float, stime_delta: float) -> None:
+ """Update the CPU time usage of this context (and any parents, recursively).
+
+ Args:
+ utime_delta: additional user time, in seconds, spent in this context.
+ stime_delta: additional system time, in seconds, spent in this context.
+ """
+ self._resource_usage.ru_utime += utime_delta
+ self._resource_usage.ru_stime += stime_delta
+ if self.parent_context:
+ self.parent_context.add_cputime(utime_delta, stime_delta)
+
+ def add_database_transaction(self, duration_sec: float) -> None:
+ """Record the use of a database transaction and the length of time it took.
+
+ Args:
+ duration_sec: The number of seconds the database transaction took.
+ """
if duration_sec < 0:
raise ValueError("DB txn time can only be non-negative")
self._resource_usage.db_txn_count += 1
self._resource_usage.db_txn_duration_sec += duration_sec
+ if self.parent_context:
+ self.parent_context.add_database_transaction(duration_sec)
- def add_database_scheduled(self, sched_sec):
+ def add_database_scheduled(self, sched_sec: float) -> None:
"""Record a use of the database pool
Args:
- sched_sec (float): number of seconds it took us to get a
- connection
+ sched_sec: number of seconds it took us to get a connection
"""
if sched_sec < 0:
raise ValueError("DB scheduling time can only be non-negative")
self._resource_usage.db_sched_duration_sec += sched_sec
+ if self.parent_context:
+ self.parent_context.add_database_scheduled(sched_sec)
- def record_event_fetch(self, event_count):
+ def record_event_fetch(self, event_count: int) -> None:
"""Record a number of events being fetched from the db
Args:
- event_count (int): number of events being fetched
+ event_count: number of events being fetched
"""
self._resource_usage.evt_db_fetch_count += event_count
+ if self.parent_context:
+ self.parent_context.record_event_fetch(event_count)
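
These recorders now walk up the parent chain at record time, so a child's database and CPU costs show up on the parent immediately, rather than being transferred in bulk when the child's __exit__ ran (as the deleted code in __exit__ used to do). Roughly:

    parent = LoggingContext(name="request")
    child = LoggingContext(name="worker", parent_context=parent)

    child.add_database_transaction(0.25)
    child.record_event_fetch(10)

    # the parent sees the child's usage straight away
    usage = parent.get_resource_usage()
    print(usage.db_txn_count)         # 1
    print(usage.db_txn_duration_sec)  # 0.25
    print(usage.evt_db_fetch_count)   # 10
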
class LoggingContextFilter(logging.Filter):
@@ -460,15 +545,15 @@ class LoggingContextFilter(logging.Filter):
missing fields
"""
- def __init__(self, **defaults):
+ def __init__(self, **defaults) -> None:
self.defaults = defaults
- def filter(self, record):
+ def filter(self, record) -> Literal[True]:
"""Add each fields from the logging contexts to the record.
Returns:
True to include the record in the log output.
"""
- context = LoggingContext.current_context()
+ context = current_context()
for key, value in self.defaults.items():
setattr(record, key, value)
@@ -488,26 +573,24 @@ class PreserveLoggingContext(object):
__slots__ = ["current_context", "new_context", "has_parent"]
- def __init__(self, new_context=None):
- if new_context is None:
- new_context = LoggingContext.sentinel
+ def __init__(
+ self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT
+ ) -> None:
self.new_context = new_context
- def __enter__(self):
+ def __enter__(self) -> None:
"""Captures the current logging context"""
- self.current_context = LoggingContext.set_current_context(self.new_context)
+ self.current_context = set_current_context(self.new_context)
if self.current_context:
self.has_parent = self.current_context.previous_context is not None
- if not self.current_context.alive:
- logger.debug("Entering dead context: %s", self.current_context)
- def __exit__(self, type, value, traceback):
+ def __exit__(self, type, value, traceback) -> None:
"""Restores the current logging context"""
- context = LoggingContext.set_current_context(self.current_context)
+ context = set_current_context(self.current_context)
if context != self.new_context:
- if context is LoggingContext.sentinel:
+ if not context:
logger.warning("Expected logging context %s was lost", self.new_context)
else:
logger.warning(
@@ -516,12 +599,42 @@ class PreserveLoggingContext(object):
context,
)
- if self.current_context is not LoggingContext.sentinel:
- if not self.current_context.alive:
- logger.debug("Restoring dead context: %s", self.current_context)
+_thread_local = threading.local()
+_thread_local.current_context = SENTINEL_CONTEXT
+
+
+def current_context() -> LoggingContextOrSentinel:
+ """Get the current logging context from thread local storage"""
+ return getattr(_thread_local, "current_context", SENTINEL_CONTEXT)
-def nested_logging_context(suffix, parent_context=None):
+
+def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
+ """Set the current logging context in thread local storage
+ Args:
+ context(LoggingContext): The context to activate.
+ Returns:
+ The context that was previously active
+ """
+ # everything blows up if we allow current_context to be set to None, so sanity-check
+ # that now.
+ if context is None:
+ raise TypeError("'context' argument may not be None")
+
+ current = current_context()
+
+ if current is not context:
+ rusage = get_thread_resource_usage()
+ current.stop(rusage)
+ _thread_local.current_context = context
+ context.start(rusage)
+
+ return current
+
+
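
set_current_context() is now the single point where contexts are swapped, and it samples the thread's rusage exactly once, handing the same snapshot to the outgoing context's stop() and the incoming context's start(), so no CPU time is lost or double-counted at the boundary. In use:

    ctx = LoggingContext(name="background")
    old = set_current_context(ctx)   # old.stop(rusage); ctx.start(rusage)
    try:
        assert current_context() is ctx
    finally:
        set_current_context(old)     # ctx.stop(rusage); old.start(rusage)
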
+def nested_logging_context(
+ suffix: str, parent_context: Optional[LoggingContext] = None
+) -> LoggingContext:
"""Creates a new logging context as a child of another.
The nested logging context will have a 'request' made up of the parent context's
@@ -542,10 +655,12 @@ def nested_logging_context(suffix, parent_context=None):
Returns:
LoggingContext: new logging context.
"""
- if parent_context is None:
- parent_context = LoggingContext.current_context()
+ if parent_context is not None:
+ context = parent_context # type: LoggingContextOrSentinel
+ else:
+ context = current_context()
return LoggingContext(
- parent_context=parent_context, request=parent_context.request + "-" + suffix
+ parent_context=context, request=str(context.request) + "-" + suffix
)
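
nested_logging_context() now stringifies the parent's request before concatenating, so calling it under the sentinel (whose request is None) yields "None-suffix" instead of raising a TypeError. Typical use:

    with LoggingContext(request="GET-123"):
        with nested_logging_context("db") as child:
            print(child.request)  # GET-123-db
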
@@ -567,12 +682,15 @@ def run_in_background(f, *args, **kwargs):
yield or await on (for instance because you want to pass it to
deferred.gatherResults()).
+ If f returns a Coroutine object, it will be wrapped into a Deferred (which will have
+ the side effect of executing the coroutine).
+
Note that if you completely discard the result, you should make sure that
`f` doesn't raise any deferred exceptions, otherwise a scary-looking
CRITICAL error about an unhandled error will be logged without much
indication about where it came from.
"""
- current = LoggingContext.current_context()
+ current = current_context()
try:
res = f(*args, **kwargs)
except: # noqa: E722
@@ -593,7 +711,7 @@ def run_in_background(f, *args, **kwargs):
# The function may have reset the context before returning, so
# we need to restore it now.
- ctx = LoggingContext.set_current_context(current)
+ ctx = set_current_context(current)
# The original context will be restored when the deferred
# completes, but there is nothing waiting for it, so it will
@@ -612,7 +730,8 @@ def run_in_background(f, *args, **kwargs):
def make_deferred_yieldable(deferred):
- """Given a deferred, make it follow the Synapse logcontext rules:
+ """Given a deferred (or coroutine), make it follow the Synapse logcontext
+ rules:
If the deferred has completed (or is not actually a Deferred), essentially
does nothing (just returns another completed deferred with the
@@ -624,6 +743,13 @@ def make_deferred_yieldable(deferred):
(This is more-or-less the opposite operation to run_in_background.)
"""
+ if inspect.isawaitable(deferred):
+ # If we're given a coroutine we convert it to a deferred so that we
+ # run it and find out if it immediately finishes; if it does then we
+ # don't need to fiddle with log contexts at all and can return
+ # immediately.
+ deferred = defer.ensureDeferred(deferred)
+
if not isinstance(deferred, defer.Deferred):
return deferred
@@ -634,14 +760,17 @@ def make_deferred_yieldable(deferred):
# ok, we can't be sure that a yield won't block, so let's reset the
# logcontext, and add a callback to the deferred to restore it.
- prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+ prev_context = set_current_context(SENTINEL_CONTEXT)
deferred.addBoth(_set_context_cb, prev_context)
return deferred
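
With the isawaitable() check, make_deferred_yieldable accepts coroutine objects as well as Deferreds; defer.ensureDeferred starts the coroutine, and if it completes immediately no logcontext juggling is needed at all. A sketch of the call pattern (do_thing is a hypothetical helper):

    from twisted.internet import defer

    async def do_thing():
        return 42

    @defer.inlineCallbacks
    def handler():
        # a coroutine object can now be passed straight in
        result = yield make_deferred_yieldable(do_thing())
        return result
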
-def _set_context_cb(result, context):
+ResultT = TypeVar("ResultT")
+
+
+def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
"""A callback function which just sets the logging context"""
- LoggingContext.set_current_context(context)
+ set_current_context(context)
return result
@@ -709,7 +838,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
Deferred: A Deferred which fires a callback with the result of `f`, or an
errback if `f` throws an exception.
"""
- logcontext = LoggingContext.current_context()
+ logcontext = current_context()
def g():
with LoggingContext(parent_context=logcontext):
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 308a27213b..5dddf57008 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -169,7 +169,9 @@ import contextlib
import inspect
import logging
import re
+import types
from functools import wraps
+from typing import TYPE_CHECKING, Dict
from canonicaljson import json
@@ -177,6 +179,9 @@ from twisted.internet import defer
from synapse.config import ConfigError
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
# Helper class
@@ -295,14 +300,11 @@ def _noop_context_manager(*args, **kwargs):
# Setup
-def init_tracer(config):
+def init_tracer(hs: "HomeServer"):
"""Set the whitelists and initialise the JaegerClient tracer
-
- Args:
- config (HomeserverConfig): The config used by the homeserver
"""
global opentracing
- if not config.opentracer_enabled:
+ if not hs.config.opentracer_enabled:
# We don't have a tracer
opentracing = None
return
@@ -313,18 +315,15 @@ def init_tracer(config):
"installed."
)
- # Include the worker name
- name = config.worker_name if config.worker_name else "master"
-
# Pull out the jaeger config if it was given. Otherwise set it to something sensible.
# See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
- set_homeserver_whitelist(config.opentracer_whitelist)
+ set_homeserver_whitelist(hs.config.opentracer_whitelist)
JaegerConfig(
- config=config.jaeger_config,
- service_name="{} {}".format(config.server_name, name),
- scope_manager=LogContextScopeManager(config),
+ config=hs.config.jaeger_config,
+ service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()),
+ scope_manager=LogContextScopeManager(hs.config),
).initialize_tracer()
@@ -547,7 +546,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
return
span = opentracing.tracer.active_span
- carrier = {}
+ carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@@ -584,7 +583,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
span = opentracing.tracer.active_span
- carrier = {}
+ carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@@ -639,7 +638,7 @@ def get_active_span_text_map(destination=None):
if destination and not whitelisted_homeserver(destination):
return {}
- carrier = {}
+ carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
)
@@ -653,7 +652,7 @@ def active_span_context_as_string():
Returns:
The active span context encoded as a string.
"""
- carrier = {}
+ carrier = {} # type: Dict[str, str]
if opentracing:
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
@@ -777,8 +776,7 @@ def trace_servlet(servlet_name, extract_context=False):
return func
@wraps(func)
- @defer.inlineCallbacks
- def _trace_servlet_inner(request, *args, **kwargs):
+ async def _trace_servlet_inner(request, *args, **kwargs):
request_tags = {
"request_id": request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
@@ -795,8 +793,14 @@ def trace_servlet(servlet_name, extract_context=False):
scope = start_active_span(servlet_name, tags=request_tags)
with scope:
- result = yield defer.maybeDeferred(func, request, *args, **kwargs)
- return result
+ result = func(request, *args, **kwargs)
+
+ if not isinstance(result, (types.CoroutineType, defer.Deferred)):
+ # Some servlets aren't async and just return results
+ # directly, so we handle that here.
+ return result
+
+ return await result
return _trace_servlet_inner
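
The rewritten wrapper drops defer.maybeDeferred in favour of an explicit type check: plain synchronous servlets return their value directly, while coroutines and Deferreds are awaited before the tracing scope is closed. The dispatch logic in isolation (dispatch is an illustrative name, not from the patch):

    import types

    from twisted.internet import defer

    async def dispatch(func, request, *args, **kwargs):
        result = func(request, *args, **kwargs)
        if not isinstance(result, (types.CoroutineType, defer.Deferred)):
            # synchronous servlet: the value is already the response
            return result
        # async servlet: await the coroutine/Deferred inside the span
        return await result
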
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index 4eed4f2338..dc3ab00cbb 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -19,7 +19,7 @@ from opentracing import Scope, ScopeManager
import twisted
-from synapse.logging.context import LoggingContext, nested_logging_context
+from synapse.logging.context import current_context, nested_logging_context
logger = logging.getLogger(__name__)
@@ -49,11 +49,8 @@ class LogContextScopeManager(ScopeManager):
(Scope) : the Scope that is active, or None if not
available.
"""
- ctx = LoggingContext.current_context()
- if ctx is LoggingContext.sentinel:
- return None
- else:
- return ctx.scope
+ ctx = current_context()
+ return ctx.scope
def activate(self, span, finish_on_close):
"""
@@ -70,9 +67,9 @@ class LogContextScopeManager(ScopeManager):
"""
enter_logcontext = False
- ctx = LoggingContext.current_context()
+ ctx = current_context()
- if ctx is LoggingContext.sentinel:
+ if not ctx:
# We don't want this scope to affect.
logger.error("Tried to activate scope outside of loggingcontext")
return Scope(None, span)
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index 7df0fa6087..99049bb5d8 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -20,8 +20,6 @@ import time
from functools import wraps
from inspect import getcallargs
-from six import PY3
-
_TIME_FUNC_ID = 0
@@ -30,12 +28,8 @@ def _log_debug_as_f(f, msg, msg_args):
logger = logging.getLogger(name)
if logger.isEnabledFor(logging.DEBUG):
- if PY3:
- lineno = f.__code__.co_firstlineno
- pathname = f.__code__.co_filename
- else:
- lineno = f.func_code.co_firstlineno
- pathname = f.func_code.co_filename
+ lineno = f.__code__.co_firstlineno
+ pathname = f.__code__.co_filename
record = logging.LogRecord(
name=name,
@@ -119,7 +113,11 @@ def trace_function(f):
logger = logging.getLogger(name)
level = logging.DEBUG
- s = inspect.currentframe().f_back
+ frame = inspect.currentframe()
+ if frame is None:
+ raise Exception("Can't get current frame!")
+
+ s = frame.f_back
to_print = [
"\t%s:%s %s. Args: args=%s, kwargs=%s"
@@ -144,7 +142,7 @@ def trace_function(f):
pathname=pathname,
lineno=lineno,
msg=msg,
- args=None,
+ args=(),
exc_info=None,
)
@@ -157,7 +155,12 @@ def trace_function(f):
def get_previous_frames():
- s = inspect.currentframe().f_back.f_back
+
+ frame = inspect.currentframe()
+ if frame is None:
+ raise Exception("Can't get current frame!")
+
+ s = frame.f_back.f_back
to_return = []
while s:
if s.f_globals["__name__"].startswith("synapse"):
@@ -174,7 +177,10 @@ def get_previous_frames():
def get_previous_frame(ignore=[]):
- s = inspect.currentframe().f_back.f_back
+ frame = inspect.currentframe()
+ if frame is None:
+ raise Exception("Can't get current frame!")
+ s = frame.f_back.f_back
while s:
if s.f_globals["__name__"].startswith("synapse"):
|