summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/util/metrics.py23
-rw-r--r--tests/test_state.py4
2 files changed, 19 insertions, 8 deletions
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index c51b641125..e1f374807e 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -50,7 +50,7 @@ block_db_txn_duration = metrics.register_distribution(
 class Measure(object):
     __slots__ = [
         "clock", "name", "start_context", "start", "new_context", "ru_utime",
-        "ru_stime", "db_txn_count", "db_txn_duration"
+        "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
     ]
 
     def __init__(self, clock, name):
@@ -58,14 +58,20 @@ class Measure(object):
         self.name = name
         self.start_context = None
         self.start = None
+        self.created_context = False
 
     def __enter__(self):
         self.start = self.clock.time_msec()
         self.start_context = LoggingContext.current_context()
-        if self.start_context:
-            self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
-            self.db_txn_count = self.start_context.db_txn_count
-            self.db_txn_duration = self.start_context.db_txn_duration
+        if not self.start_context:
+            logger.warn("Entered Measure without log context: %s", self.name)
+            self.start_context = LoggingContext("Measure")
+            self.start_context.__enter__()
+            self.created_context = True
+
+        self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
+        self.db_txn_count = self.start_context.db_txn_count
+        self.db_txn_duration = self.start_context.db_txn_duration
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if exc_type is not None or not self.start_context:
@@ -91,7 +97,12 @@ class Measure(object):
 
         block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
         block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
-        block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name)
+        block_db_txn_count.inc_by(
+            context.db_txn_count - self.db_txn_count, self.name
+        )
         block_db_txn_duration.inc_by(
             context.db_txn_duration - self.db_txn_duration, self.name
         )
+
+        if self.created_context:
+            self.start_context.__exit__(exc_type, exc_val, exc_tb)
diff --git a/tests/test_state.py b/tests/test_state.py
index a1ea7ef672..1a11bbcee0 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -140,13 +140,13 @@ class StateTestCase(unittest.TestCase):
                 "add_event_hashes",
             ]
         )
-        hs = Mock(spec=[
+        hs = Mock(spec_set=[
             "get_datastore", "get_auth", "get_state_handler", "get_clock",
         ])
         hs.get_datastore.return_value = self.store
         hs.get_state_handler.return_value = None
-        hs.get_auth.return_value = Auth(hs)
         hs.get_clock.return_value = MockClock()
+        hs.get_auth.return_value = Auth(hs)
 
         self.state = StateHandler(hs)
         self.event_id = 0