diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index e4ce087afe..c20c89aa8f 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -87,13 +87,26 @@ class LoggingContext(object):
"""Get the current logging context from thread local storage"""
return getattr(cls.thread_local, "current_context", cls.sentinel)
+ @classmethod
+ def set_current_context(cls, context):
+ """Set the current logging context in thread local storage
+ Args:
+ context(LoggingContext): The context to activate.
+ Returns:
+ The context that was previously active
+ """
+ current = cls.current_context()
+ if current is not context:
+ current.stop()
+ cls.thread_local.current_context = context
+ context.start()
+ return current
+
def __enter__(self):
"""Enters this logging context into thread local storage"""
if self.parent_context is not None:
raise Exception("Attempt to enter logging context multiple times")
- self.parent_context = self.current_context()
- self.thread_local.current_context = self
- self.start()
+ self.parent_context = self.set_current_context(self)
return self
def __exit__(self, type, value, traceback):
@@ -102,17 +115,16 @@ class LoggingContext(object):
Returns:
None to avoid suppressing any exeptions that were thrown.
"""
- if self.thread_local.current_context is not self:
- if self.thread_local.current_context is self.sentinel:
+ current = self.set_current_context(self.parent_context)
+ if current is not self:
+ if current is self.sentinel:
logger.debug("Expected logging context %s has been lost", self)
else:
logger.warn(
"Current logging context %s is not expected context %s",
- self.thread_local.current_context,
+ current,
self
)
- self.thread_local.current_context = self.parent_context
- self.stop()
self.parent_context = None
def __getattr__(self, name):
@@ -194,17 +206,13 @@ class PreserveLoggingContext(object):
def __enter__(self):
"""Captures the current logging context"""
- self.current_context = LoggingContext.current_context()
- if self.new_context is not self.current_context:
- self.current_context.stop()
- LoggingContext.thread_local.current_context = self.new_context
+ self.current_context = LoggingContext.set_current_context(
+ self.new_context
+ )
def __exit__(self, type, value, traceback):
"""Restores the current logging context"""
- context = LoggingContext.thread_local.current_context
- LoggingContext.thread_local.current_context = self.current_context
- if context is not self.current_context:
- self.current_context.start()
+ LoggingContext.set_current_context(self.current_context)
if self.current_context is not LoggingContext.sentinel:
if self.current_context.parent_context is None:
logger.warn(
|