summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/__init__.py3
-rw-r--r--synapse/util/debug.py3
-rw-r--r--synapse/util/logcontext.py84
3 files changed, 80 insertions, 10 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index d69c7cb991..2170746025 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -64,8 +64,7 @@ class Clock(object):
         current_context = LoggingContext.current_context()
 
         def wrapped_callback(*args, **kwargs):
-            with PreserveLoggingContext():
-                LoggingContext.thread_local.current_context = current_context
+            with PreserveLoggingContext(current_context):
                 callback(*args, **kwargs)
 
         with PreserveLoggingContext():
diff --git a/synapse/util/debug.py b/synapse/util/debug.py
index f6a5a841a4..b2bee7958f 100644
--- a/synapse/util/debug.py
+++ b/synapse/util/debug.py
@@ -30,8 +30,7 @@ def debug_deferreds():
         context = LoggingContext.current_context()
 
         def restore_context_callback(x):
-            with PreserveLoggingContext():
-                LoggingContext.thread_local.current_context = context
+            with PreserveLoggingContext(context):
                 return fn(x)
 
         return restore_context_callback
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 7e6062c1b8..e4ce087afe 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -19,6 +19,25 @@ import logging
 
 logger = logging.getLogger(__name__)
 
+try:
+    import resource
+
+    # Python doesn't ship with a definition of RUSAGE_THREAD but it's defined
+    # to be 1 on linux so we hard code it.
+    RUSAGE_THREAD = 1
+
+    # If the system doesn't support RUSAGE_THREAD then this should throw an
+    # exception.
+    resource.getrusage(RUSAGE_THREAD)
+
+    def get_thread_resource_usage():
+        return resource.getrusage(RUSAGE_THREAD)
+except:
+    # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
+    # won't track resource usage by returning None.
+    def get_thread_resource_usage():
+        return None
+
 
 class LoggingContext(object):
     """Additional context for log formatting. Contexts are scoped within a
@@ -27,7 +46,9 @@ class LoggingContext(object):
         name (str): Name for the context for debugging.
     """
 
-    __slots__ = ["parent_context", "name", "__dict__"]
+    __slots__ = [
+        "parent_context", "name", "usage_start", "usage_end", "main_thread", "__dict__"
+    ]
 
     thread_local = threading.local()
 
@@ -42,11 +63,21 @@ class LoggingContext(object):
         def copy_to(self, record):
             pass
 
+        def start(self):
+            pass
+
+        def stop(self):
+            pass
+
     sentinel = Sentinel()
 
     def __init__(self, name=None):
         self.parent_context = None
         self.name = name
+        self.ru_stime = 0.
+        self.ru_utime = 0.
+        self.usage_start = None
+        self.main_thread = threading.current_thread()
 
     def __str__(self):
         return "%s@%x" % (self.name, id(self))
@@ -62,6 +93,7 @@ class LoggingContext(object):
             raise Exception("Attempt to enter logging context multiple times")
         self.parent_context = self.current_context()
         self.thread_local.current_context = self
+        self.start()
         return self
 
     def __exit__(self, type, value, traceback):
@@ -80,6 +112,7 @@ class LoggingContext(object):
                     self
                 )
         self.thread_local.current_context = self.parent_context
+        self.stop()
         self.parent_context = None
 
     def __getattr__(self, name):
@@ -93,6 +126,39 @@ class LoggingContext(object):
         for key, value in self.__dict__.items():
             setattr(record, key, value)
 
+        record.ru_utime, record.ru_stime = self.get_resource_usage()
+
+    def start(self):
+        if threading.current_thread() is not self.main_thread:
+            return
+
+        if self.usage_start and self.usage_end:
+            self.ru_utime += self.usage_end.ru_utime - self.usage_start.ru_utime
+            self.ru_stime += self.usage_end.ru_stime - self.usage_start.ru_stime
+            self.usage_start = None
+            self.usage_end = None
+
+        if not self.usage_start:
+            self.usage_start = get_thread_resource_usage()
+
+    def stop(self):
+        if threading.current_thread() is not self.main_thread:
+            return
+
+        if self.usage_start:
+            self.usage_end = get_thread_resource_usage()
+
+    def get_resource_usage(self):
+        ru_utime = self.ru_utime
+        ru_stime = self.ru_stime
+
+        if self.usage_start and threading.current_thread() is self.main_thread:
+            current = get_thread_resource_usage()
+            ru_utime += current.ru_utime - self.usage_start.ru_utime
+            ru_stime += current.ru_stime - self.usage_start.ru_stime
+
+        return ru_utime, ru_stime
+
 
 class LoggingContextFilter(logging.Filter):
     """Logging filter that adds values from the current logging context to each
@@ -121,17 +187,24 @@ class PreserveLoggingContext(object):
     exited. Used to restore the context after a function using
     @defer.inlineCallbacks is resumed by a callback from the reactor."""
 
-    __slots__ = ["current_context"]
+    __slots__ = ["current_context", "new_context"]
+
+    def __init__(self, new_context=LoggingContext.sentinel):
+        self.new_context = new_context
 
     def __enter__(self):
         """Captures the current logging context"""
         self.current_context = LoggingContext.current_context()
-        LoggingContext.thread_local.current_context = LoggingContext.sentinel
+        if self.new_context is not self.current_context:
+            self.current_context.stop()
+        LoggingContext.thread_local.current_context = self.new_context
 
     def __exit__(self, type, value, traceback):
         """Restores the current logging context"""
+        context = LoggingContext.thread_local.current_context
         LoggingContext.thread_local.current_context = self.current_context
-
+        if context is not self.current_context:
+            self.current_context.start()
         if self.current_context is not LoggingContext.sentinel:
             if self.current_context.parent_context is None:
                 logger.warn(
@@ -164,8 +237,7 @@ class _PreservingContextDeferred(defer.Deferred):
 
     def _wrap_callback(self, f):
         def g(res, *args, **kwargs):
-            with PreserveLoggingContext():
-                LoggingContext.thread_local.current_context = self._log_context
+            with PreserveLoggingContext(self._log_context):
                 res = f(res, *args, **kwargs)
             return res
         return g