summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2015-12-07 10:51:18 +0000
committerMark Haines <mark.haines@matrix.org>2015-12-07 10:51:18 +0000
commit3dd16308487d4b5f76d8b3f3e0bf5ce2a72aff22 (patch)
tree2499b6680ad4a39c73d6859157dacd8ce6d4e201 /synapse/util
parentMerge pull request #420 from matrix-org/markjh/resource_usage (diff)
downloadsynapse-3dd16308487d4b5f76d8b3f3e0bf5ce2a72aff22.tar.xz
Add a setter for the current log context.
Move the resource tracking inside that setter so that it is easier
to make sure that the resource tracking isn't double counting the
resource usage.
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/logcontext.py40
1 files changed, 24 insertions, 16 deletions
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index e4ce087afe..c20c89aa8f 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -87,13 +87,26 @@ class LoggingContext(object):
         """Get the current logging context from thread local storage"""
         return getattr(cls.thread_local, "current_context", cls.sentinel)
 
+    @classmethod
+    def set_current_context(cls, context):
+        """Set the current logging context in thread local storage
+        Args:
+            context(LoggingContext): The context to activate.
+        Returns:
+            The context that was previously active
+        """
+        current = cls.current_context()
+        if current is not context:
+            current.stop()
+            cls.thread_local.current_context = context
+            context.start()
+        return current
+
     def __enter__(self):
         """Enters this logging context into thread local storage"""
         if self.parent_context is not None:
             raise Exception("Attempt to enter logging context multiple times")
-        self.parent_context = self.current_context()
-        self.thread_local.current_context = self
-        self.start()
+        self.parent_context = self.set_current_context(self)
         return self
 
     def __exit__(self, type, value, traceback):
@@ -102,17 +115,16 @@ class LoggingContext(object):
         Returns:
             None to avoid suppressing any exeptions that were thrown.
         """
-        if self.thread_local.current_context is not self:
-            if self.thread_local.current_context is self.sentinel:
+        current = self.set_current_context(self.parent_context)
+        if current is not self:
+            if current is self.sentinel:
                 logger.debug("Expected logging context %s has been lost", self)
             else:
                 logger.warn(
                     "Current logging context %s is not expected context %s",
-                    self.thread_local.current_context,
+                    current,
                     self
                 )
-        self.thread_local.current_context = self.parent_context
-        self.stop()
         self.parent_context = None
 
     def __getattr__(self, name):
@@ -194,17 +206,13 @@ class PreserveLoggingContext(object):
 
     def __enter__(self):
         """Captures the current logging context"""
-        self.current_context = LoggingContext.current_context()
-        if self.new_context is not self.current_context:
-            self.current_context.stop()
-        LoggingContext.thread_local.current_context = self.new_context
+        self.current_context = LoggingContext.set_current_context(
+            self.new_context
+        )
 
     def __exit__(self, type, value, traceback):
         """Restores the current logging context"""
-        context = LoggingContext.thread_local.current_context
-        LoggingContext.thread_local.current_context = self.current_context
-        if context is not self.current_context:
-            self.current_context.start()
+        LoggingContext.set_current_context(self.current_context)
         if self.current_context is not LoggingContext.sentinel:
             if self.current_context.parent_context is None:
                 logger.warn(