summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
authorPaul "LeoNerd" Evans <paul@matrix.org>2014-09-12 19:07:29 +0100
committerPaul "LeoNerd" Evans <paul@matrix.org>2014-09-12 19:07:29 +0100
commit7a77aabb4bbb997db9dadd46e49d855946c1ae2e (patch)
tree761198b97180ca8e48d657c25da42e33c24d4225 /tests/unittest.py
parentAdd some docstrings (diff)
downloadsynapse-7a77aabb4bbb997db9dadd46e49d855946c1ae2e.tar.xz
Define a CLOS-like 'around' modifier as a decorator, to neaten up the 'orig_*' noise of wrapping the setUp()/tearDown() methods
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py33
1 files changed, 23 insertions, 10 deletions
diff --git a/tests/unittest.py b/tests/unittest.py
index e437d3541a..fb97fb1148 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -26,6 +26,23 @@ logging.getLogger().addHandler(logging.StreamHandler())
 logging.getLogger().setLevel(NEVER)
 
 
+def around(target):
+    """A CLOS-style 'around' modifier, which wraps the original method of the
+    given instance with another piece of code.
+
+    @around(self)
+    def method_name(orig, *args, **kwargs):
+        return orig(*args, **kwargs)
+    """
+    def _around(code):
+        name = code.__name__
+        orig = getattr(target, name)
+        def new(*args, **kwargs):
+            return code(orig, *args, **kwargs)
+        setattr(target, name, new)
+    return _around
+
+
 class TestCase(unittest.TestCase):
     """A subclass of twisted.trial's TestCase which looks for 'loglevel'
     attributes on both itself and its individual test methods, to override the
@@ -40,23 +57,19 @@ class TestCase(unittest.TestCase):
                     getattr(self, "loglevel",
                         NEVER))
 
-        orig_setUp = self.setUp
-
-        def setUp():
+        @around(self)
+        def setUp(orig):
             old_level = logging.getLogger().level
 
             if old_level != level:
-                orig_tearDown = self.tearDown
-
-                def tearDown():
-                    ret = orig_tearDown()
+                @around(self)
+                def tearDown(orig):
+                    ret = orig()
                     logging.getLogger().setLevel(old_level)
                     return ret
-                self.tearDown = tearDown
 
             logging.getLogger().setLevel(level)
-            return orig_setUp()
-        self.setUp = setUp
+            return orig()
 
 
 def DEBUG(target):