diff options
Diffstat (limited to 'synapse/util')
-rw-r--r-- | synapse/util/async.py | 5 | ||||
-rw-r--r-- | synapse/util/logcontext.py | 108 | ||||
-rw-r--r-- | synapse/util/logutils.py | 1 |
3 files changed, 113 insertions, 1 deletions
diff --git a/synapse/util/async.py b/synapse/util/async.py index 647ea6142c..3d3fbe182c 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -16,8 +16,11 @@ from twisted.internet import defer, reactor +from .logcontext import PreserveLoggingContext +@defer.inlineCallbacks def sleep(seconds): d = defer.Deferred() reactor.callLater(seconds, d.callback, seconds) - return d + with PreserveLoggingContext(): + yield d diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py new file mode 100644 index 0000000000..13176b05ce --- /dev/null +++ b/synapse/util/logcontext.py @@ -0,0 +1,108 @@ +import threading +import logging + + +class LoggingContext(object): + """Additional context for log formatting. Contexts are scoped within a + "with" block. Contexts inherit the state of their parent contexts. + Args: + name (str): Name for the context for debugging. + """ + + __slots__ = ["parent_context", "name", "__dict__"] + + thread_local = threading.local() + + class Sentinel(object): + """Sentinel to represent the root context""" + + __slots__ = [] + + def copy_to(self, record): + pass + + sentinel = Sentinel() + + def __init__(self, name=None): + self.parent_context = None + self.name = name + + def __str__(self): + return "%s@%x" % (self.name, id(self)) + + @classmethod + def current_context(cls): + """Get the current logging context from thread local storage""" + return getattr(cls.thread_local, "current_context", cls.sentinel) + + def __enter__(self): + """Enters this logging context into thread local storage""" + if self.parent_context is not None: + raise Exception("Attempt to enter logging context multiple times") + self.parent_context = self.current_context() + self.thread_local.current_context = self + return self + + def __exit__(self, type, value, traceback): + """Restore the logging context in thread local storage to the state it + was before this context was entered. + Returns: + None to avoid suppressing any exeptions that were thrown. + """ + if self.thread_local.current_context is not self: + logging.error( + "Current logging context %s is not the expected context %s", + self.thread_local.current_context, + self + ) + self.thread_local.current_context = self.parent_context + self.parent_context = None + + def __getattr__(self, name): + """Delegate member lookup to parent context""" + return getattr(self.parent_context, name) + + def copy_to(self, record): + """Copy fields from this context and its parents to the record""" + if self.parent_context is not None: + self.parent_context.copy_to(record) + for key, value in self.__dict__.items(): + setattr(record, key, value) + + +class LoggingContextFilter(logging.Filter): + """Logging filter that adds values from the current logging context to each + record. + Args: + **defaults: Default values to avoid formatters complaining about + missing fields + """ + def __init__(self, **defaults): + self.defaults = defaults + + def filter(self, record): + """Add each fields from the logging contexts to the record. + Returns: + True to include the record in the log output. + """ + context = LoggingContext.current_context() + for key, value in self.defaults.items(): + setattr(record, key, value) + context.copy_to(record) + return True + + +class PreserveLoggingContext(object): + """Captures the current logging context and restores it when the scope is + exited. Used to restore the context after a function using + @defer.inlineCallbacks is resumed by a callback from the reactor.""" + + __slots__ = ["current_context"] + + def __enter__(self): + """Captures the current logging context""" + self.current_context = LoggingContext.current_context() + + def __exit__(self, type, value, traceback): + """Restores the current logging context""" + LoggingContext.thread_local.current_context = self.current_context diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py index fadf0bd510..903a6cf1b3 100644 --- a/synapse/util/logutils.py +++ b/synapse/util/logutils.py @@ -75,6 +75,7 @@ def trace_function(f): linenum = f.func_code.co_firstlineno pathname = f.func_code.co_filename + @wraps(f) def wrapped(*args, **kwargs): name = f.__module__ logger = logging.getLogger(name) |