summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/logging/context.py50
-rw-r--r--synapse/storage/database.py8
-rw-r--r--synapse/util/metrics.py10
3 files changed, 46 insertions, 22 deletions
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index a507a83e93..c2db8b45f3 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -252,7 +252,12 @@ class LoggingContext:
         "scope",
     ]
 
-    def __init__(self, name=None, parent_context=None, request=None) -> None:
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        parent_context: "Optional[LoggingContext]" = None,
+        request: Optional[str] = None,
+    ) -> None:
         self.previous_context = current_context()
         self.name = name
 
@@ -536,20 +541,20 @@ class LoggingContextFilter(logging.Filter):
     def __init__(self, request: str = ""):
         self._default_request = request
 
-    def filter(self, record) -> Literal[True]:
+    def filter(self, record: logging.LogRecord) -> Literal[True]:
         """Add each fields from the logging contexts to the record.
         Returns:
             True to include the record in the log output.
         """
         context = current_context()
-        record.request = self._default_request
+        record.request = self._default_request  # type: ignore
 
         # context should never be None, but if it somehow ends up being, then
         # we end up in a death spiral of infinite loops, so let's check, for
         # robustness' sake.
         if context is not None:
             # Logging is interested in the request.
-            record.request = context.request
+            record.request = context.request  # type: ignore
 
         return True
 
@@ -616,9 +621,7 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe
     return current
 
 
-def nested_logging_context(
-    suffix: str, parent_context: Optional[LoggingContext] = None
-) -> LoggingContext:
+def nested_logging_context(suffix: str) -> LoggingContext:
     """Creates a new logging context as a child of another.
 
     The nested logging context will have a 'request' made up of the parent context's
@@ -632,20 +635,23 @@ def nested_logging_context(
             # ... do stuff
 
     Args:
-        suffix (str): suffix to add to the parent context's 'request'.
-        parent_context (LoggingContext|None): parent context. Will use the current context
-            if None.
+        suffix: suffix to add to the parent context's 'request'.
 
     Returns:
         LoggingContext: new logging context.
     """
-    if parent_context is not None:
-        context = parent_context  # type: LoggingContextOrSentinel
+    curr_context = current_context()
+    if not curr_context:
+        logger.warning(
+            "Starting nested logging context from sentinel context: metrics will be lost"
+        )
+        parent_context = None
+        prefix = ""
     else:
-        context = current_context()
-    return LoggingContext(
-        parent_context=context, request=str(context.request) + "-" + suffix
-    )
+        assert isinstance(curr_context, LoggingContext)
+        parent_context = curr_context
+        prefix = str(parent_context.request)
+    return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix)
 
 
 def preserve_fn(f):
@@ -822,10 +828,18 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
         Deferred: A Deferred which fires a callback with the result of `f`, or an
             errback if `f` throws an exception.
     """
-    logcontext = current_context()
+    curr_context = current_context()
+    if not curr_context:
+        logger.warning(
+            "Calling defer_to_threadpool from sentinel context: metrics will be lost"
+        )
+        parent_context = None
+    else:
+        assert isinstance(curr_context, LoggingContext)
+        parent_context = curr_context
 
     def g():
-        with LoggingContext(parent_context=logcontext):
+        with LoggingContext(parent_context=parent_context):
             return f(*args, **kwargs)
 
     return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d1b5760c2c..b70ca3087b 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -42,7 +42,6 @@ from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.logging.context import (
     LoggingContext,
-    LoggingContextOrSentinel,
     current_context,
     make_deferred_yieldable,
 )
@@ -671,12 +670,15 @@ class DatabasePool:
         Returns:
             The result of func
         """
-        parent_context = current_context()  # type: Optional[LoggingContextOrSentinel]
-        if not parent_context:
+        curr_context = current_context()
+        if not curr_context:
             logger.warning(
                 "Starting db connection from sentinel context: metrics will be lost"
             )
             parent_context = None
+        else:
+            assert isinstance(curr_context, LoggingContext)
+            parent_context = curr_context
 
         start_time = monotonic_time()
 
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index ffdea0de8d..24123d5cc4 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -108,7 +108,15 @@ class Measure:
     def __init__(self, clock, name):
         self.clock = clock
         self.name = name
-        parent_context = current_context()
+        curr_context = current_context()
+        if not curr_context:
+            logger.warning(
+                "Starting metrics collection from sentinel context: metrics will be lost"
+            )
+            parent_context = None
+        else:
+            assert isinstance(curr_context, LoggingContext)
+            parent_context = curr_context
         self._logging_context = LoggingContext(
             "Measure[%s]" % (self.name,), parent_context
         )