summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
authorPaul "LeoNerd" Evans <paul@matrix.org>2014-09-12 19:07:29 +0100
committerPaul "LeoNerd" Evans <paul@matrix.org>2014-09-12 19:07:29 +0100
commit7a77aabb4bbb997db9dadd46e49d855946c1ae2e (patch)
tree761198b97180ca8e48d657c25da42e33c24d4225 /tests/unittest.py
parentAdd some docstrings (diff)
downloadsynapse-7a77aabb4bbb997db9dadd46e49d855946c1ae2e.tar.xz
Define a CLOS-like 'around' modifier as a decorator, to neaten up the 'orig_*' noise of wrapping the setUp()/tearDown() methods
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py33
1 files changed, 23 insertions, 10 deletions
diff --git a/tests/unittest.py b/tests/unittest.py

index e437d3541a..fb97fb1148 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -26,6 +26,23 @@ logging.getLogger().addHandler(logging.StreamHandler()) logging.getLogger().setLevel(NEVER) +def around(target): + """A CLOS-style 'around' modifier, which wraps the original method of the + given instance with another piece of code. + + @around(self) + def method_name(orig, *args, **kwargs): + return orig(*args, **kwargs) + """ + def _around(code): + name = code.__name__ + orig = getattr(target, name) + def new(*args, **kwargs): + return code(orig, *args, **kwargs) + setattr(target, name, new) + return _around + + class TestCase(unittest.TestCase): """A subclass of twisted.trial's TestCase which looks for 'loglevel' attributes on both itself and its individual test methods, to override the @@ -40,23 +57,19 @@ class TestCase(unittest.TestCase): getattr(self, "loglevel", NEVER)) - orig_setUp = self.setUp - - def setUp(): + @around(self) + def setUp(orig): old_level = logging.getLogger().level if old_level != level: - orig_tearDown = self.tearDown - - def tearDown(): - ret = orig_tearDown() + @around(self) + def tearDown(orig): + ret = orig() logging.getLogger().setLevel(old_level) return ret - self.tearDown = tearDown logging.getLogger().setLevel(level) - return orig_setUp() - self.setUp = setUp + return orig() def DEBUG(target):