diff --git a/tests/unittest.py b/tests/unittest.py
index 38715972dd..b15b06726b 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -13,22 +13,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+import twisted
+import twisted.logger
from twisted.trial import unittest
-import logging
+from synapse.util.logcontext import LoggingContextFilter
+
+# Set up putting Synapse's logs into Trial's.
+rootLogger = logging.getLogger()
+
+log_format = (
+ "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
+)
+
-# logging doesn't have a "don't log anything at all EVARRRR setting,
-# but since the highest value is 50, 1000000 should do ;)
-NEVER = 1000000
+class ToTwistedHandler(logging.Handler):
+ tx_log = twisted.logger.Logger()
-handler = logging.StreamHandler()
-handler.setFormatter(logging.Formatter(
- "%(levelname)s:%(name)s:%(message)s [%(pathname)s:%(lineno)d]"
-))
-logging.getLogger().addHandler(handler)
-logging.getLogger().setLevel(NEVER)
-logging.getLogger("synapse.storage.SQL").setLevel(NEVER)
-logging.getLogger("synapse.storage.txn").setLevel(NEVER)
+ def emit(self, record):
+ log_entry = self.format(record)
+ log_level = record.levelname.lower().replace('warning', 'warn')
+ self.tx_log.emit(
+ twisted.logger.LogLevel.levelWithName(log_level),
+ log_entry.replace("{", r"(").replace("}", r")"),
+ )
+
+
+handler = ToTwistedHandler()
+formatter = logging.Formatter(log_format)
+handler.setFormatter(formatter)
+handler.addFilter(LoggingContextFilter(request=""))
+rootLogger.addHandler(handler)
def around(target):
@@ -61,10 +78,14 @@ class TestCase(unittest.TestCase):
method = getattr(self, methodName)
- level = getattr(method, "loglevel", getattr(self, "loglevel", NEVER))
+ level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR))
@around(self)
def setUp(orig):
+ # enable debugging of delayed calls - this means that we get a
+ # traceback when a unit test exits leaving things on the reactor.
+ twisted.internet.base.DelayedCall.debug = True
+
old_level = logging.getLogger().level
if old_level != level:
@@ -88,6 +109,17 @@ class TestCase(unittest.TestCase):
except AssertionError as e:
raise (type(e))(e.message + " for '.%s'" % key)
+ def assert_dict(self, required, actual):
+ """Does a partial assert of a dict.
+
+ Args:
+ required (dict): The keys and value which MUST be in 'actual'.
+ actual (dict): The test result. Extra keys will not be checked.
+ """
+ for key in required:
+ self.assertEquals(required[key], actual[key],
+ msg="%s mismatch. %s" % (key, actual))
+
def DEBUG(target):
"""A decorator to set the .loglevel attribute to logging.DEBUG.
|