diff options
-rw-r--r-- | tests/unittest.py | 33 |
1 files changed, 23 insertions, 10 deletions
diff --git a/tests/unittest.py b/tests/unittest.py index e437d3541a..fb97fb1148 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -26,6 +26,23 @@ logging.getLogger().addHandler(logging.StreamHandler()) logging.getLogger().setLevel(NEVER) +def around(target): + """A CLOS-style 'around' modifier, which wraps the original method of the + given instance with another piece of code. + + @around(self) + def method_name(orig, *args, **kwargs): + return orig(*args, **kwargs) + """ + def _around(code): + name = code.__name__ + orig = getattr(target, name) + def new(*args, **kwargs): + return code(orig, *args, **kwargs) + setattr(target, name, new) + return _around + + class TestCase(unittest.TestCase): """A subclass of twisted.trial's TestCase which looks for 'loglevel' attributes on both itself and its individual test methods, to override the @@ -40,23 +57,19 @@ class TestCase(unittest.TestCase): getattr(self, "loglevel", NEVER)) - orig_setUp = self.setUp - - def setUp(): + @around(self) + def setUp(orig): old_level = logging.getLogger().level if old_level != level: - orig_tearDown = self.tearDown - - def tearDown(): - ret = orig_tearDown() + @around(self) + def tearDown(orig): + ret = orig() logging.getLogger().setLevel(old_level) return ret - self.tearDown = tearDown logging.getLogger().setLevel(level) - return orig_setUp() - self.setUp = setUp + return orig() def DEBUG(target): |