summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8399.misc1
-rw-r--r--synapse/logging/context.py43
-rw-r--r--tests/crypto/test_keyring.py3
-rw-r--r--tests/unittest.py15
4 files changed, 41 insertions, 21 deletions
diff --git a/changelog.d/8399.misc b/changelog.d/8399.misc
new file mode 100644
index 0000000000..ce6e8123cf
--- /dev/null
+++ b/changelog.d/8399.misc
@@ -0,0 +1 @@
+Create a mechanism for marking tests "logcontext clean".
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 2e282d9d67..ca0c774cc5 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -65,6 +65,11 @@ except Exception:
         return None
 
 
+# a hook which can be set during testing to assert that we aren't abusing logcontexts.
+def logcontext_error(msg: str):
+    logger.warning(msg)
+
+
 # get an id for the current thread.
 #
 # threading.get_ident doesn't actually return an OS-level tid, and annoyingly,
@@ -330,10 +335,9 @@ class LoggingContext:
         """Enters this logging context into thread local storage"""
         old_context = set_current_context(self)
         if self.previous_context != old_context:
-            logger.warning(
-                "Expected previous context %r, found %r",
-                self.previous_context,
-                old_context,
+            logcontext_error(
+                "Expected previous context %r, found %r"
+                % (self.previous_context, old_context,)
             )
         return self
 
@@ -346,10 +350,10 @@ class LoggingContext:
         current = set_current_context(self.previous_context)
         if current is not self:
             if current is SENTINEL_CONTEXT:
-                logger.warning("Expected logging context %s was lost", self)
+                logcontext_error("Expected logging context %s was lost" % (self,))
             else:
-                logger.warning(
-                    "Expected logging context %s but found %s", self, current
+                logcontext_error(
+                    "Expected logging context %s but found %s" % (self, current)
                 )
 
         # the fact that we are here suggests that the caller thinks that everything
@@ -387,16 +391,16 @@ class LoggingContext:
                 support getrusuage.
         """
         if get_thread_id() != self.main_thread:
-            logger.warning("Started logcontext %s on different thread", self)
+            logcontext_error("Started logcontext %s on different thread" % (self,))
             return
 
         if self.finished:
-            logger.warning("Re-starting finished log context %s", self)
+            logcontext_error("Re-starting finished log context %s" % (self,))
 
         # If we haven't already started record the thread resource usage so
         # far
         if self.usage_start:
-            logger.warning("Re-starting already-active log context %s", self)
+            logcontext_error("Re-starting already-active log context %s" % (self,))
         else:
             self.usage_start = rusage
 
@@ -414,7 +418,7 @@ class LoggingContext:
 
         try:
             if get_thread_id() != self.main_thread:
-                logger.warning("Stopped logcontext %s on different thread", self)
+                logcontext_error("Stopped logcontext %s on different thread" % (self,))
                 return
 
             if not rusage:
@@ -422,9 +426,9 @@ class LoggingContext:
 
             # Record the cpu used since we started
             if not self.usage_start:
-                logger.warning(
-                    "Called stop on logcontext %s without recording a start rusage",
-                    self,
+                logcontext_error(
+                    "Called stop on logcontext %s without recording a start rusage"
+                    % (self,)
                 )
                 return
 
@@ -584,14 +588,13 @@ class PreserveLoggingContext:
 
         if context != self._new_context:
             if not context:
-                logger.warning(
-                    "Expected logging context %s was lost", self._new_context
+                logcontext_error(
+                    "Expected logging context %s was lost" % (self._new_context,)
                 )
             else:
-                logger.warning(
-                    "Expected logging context %s but found %s",
-                    self._new_context,
-                    context,
+                logcontext_error(
+                    "Expected logging context %s but found %s"
+                    % (self._new_context, context,)
                 )
 
 
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 5cf408f21f..8ff1460c0d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -41,6 +41,7 @@ from synapse.storage.keys import FetchKeyResult
 
 from tests import unittest
 from tests.test_utils import make_awaitable
+from tests.unittest import logcontext_clean
 
 
 class MockPerspectiveServer:
@@ -67,6 +68,7 @@ class MockPerspectiveServer:
         signedjson.sign.sign_json(res, self.server_name, self.key)
 
 
+@logcontext_clean
 class KeyringTestCase(unittest.HomeserverTestCase):
     def check_context(self, val, expected):
         self.assertEquals(getattr(current_context(), "request", None), expected)
@@ -309,6 +311,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         mock_fetcher2.get_keys.assert_called_once()
 
 
+@logcontext_clean
 class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         self.http_client = Mock()
diff --git a/tests/unittest.py b/tests/unittest.py
index dabf69cff4..bbe50c3851 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -23,7 +23,7 @@ import logging
 import time
 from typing import Optional, Tuple, Type, TypeVar, Union
 
-from mock import Mock
+from mock import Mock, patch
 
 from canonicaljson import json
 
@@ -169,6 +169,19 @@ def INFO(target):
     return target
 
 
+def logcontext_clean(target):
+    """A decorator which marks the TestCase or method as 'logcontext_clean'
+
+    ... ie, any logcontext errors should cause a test failure
+    """
+
+    def logcontext_error(msg):
+        raise AssertionError("logcontext error: %s" % (msg))
+
+    patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
+    return patcher(target)
+
+
 class HomeserverTestCase(TestCase):
     """
     A base TestCase that reduces boilerplate for HomeServer-using test cases.