summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py55
1 files changed, 54 insertions, 1 deletions
diff --git a/tests/unittest.py b/tests/unittest.py
index 3ebed6e612..cabe787cb4 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -157,6 +157,21 @@ class HomeserverTestCase(TestCase):
     """
     A base TestCase that reduces boilerplate for HomeServer-using test cases.
 
+    Defines a setUp method which creates a mock reactor, and instantiates a homeserver
+    running on that reactor.
+
+    There are various hooks for modifying the way that the homeserver is instantiated:
+
+    * override make_homeserver, for example by making it pass different parameters into
+      setup_test_homeserver.
+
+    * override default_config, to return a modified configuration dictionary for use
+      by setup_test_homeserver.
+
+    * On a per-test basis, you can use the @override_config decorator to give a
+      dictionary containing additional configuration settings to be added to the basic
+      config dict.
+
     Attributes:
         servlets (list[function]): List of servlet registration function.
         user_id (str): The user ID to assume if auth is hijacked.
@@ -168,6 +183,13 @@ class HomeserverTestCase(TestCase):
     hijack_auth = True
     needs_threadpool = False
 
+    def __init__(self, methodName, *args, **kwargs):
+        super().__init__(methodName, *args, **kwargs)
+
+        # see if we have any additional config for this test
+        method = getattr(self, methodName)
+        self._extra_config = getattr(method, "_extra_config", None)
+
     def setUp(self):
         """
         Set up the TestCase by calling the homeserver constructor, optionally
@@ -276,7 +298,14 @@ class HomeserverTestCase(TestCase):
         Args:
             name (str): The homeserver name/domain.
         """
-        return default_config(name)
+        config = default_config(name)
+
+        # apply any additional config which was specified via the override_config
+        # decorator.
+        if self._extra_config is not None:
+            config.update(self._extra_config)
+
+        return config
 
     def prepare(self, reactor, clock, homeserver):
         """
@@ -534,3 +563,27 @@ class HomeserverTestCase(TestCase):
         )
         self.render(request)
         self.assertEqual(channel.code, 403, channel.result)
+
+
+def override_config(extra_config):
+    """A decorator which can be applied to test functions to give additional HS config
+
+    For use
+
+    For example:
+
+        class MyTestCase(HomeserverTestCase):
+            @override_config({"enable_registration": False, ...})
+            def test_foo(self):
+                ...
+
+    Args:
+        extra_config(dict): Additional config settings to be merged into the default
+            config dict before instantiating the test homeserver.
+    """
+
+    def decorator(func):
+        func._extra_config = extra_config
+        return func
+
+    return decorator