diff --git a/tests/unittest.py b/tests/unittest.py
index 4c0fb029fd..561cebc223 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -151,6 +151,21 @@ class HomeserverTestCase(TestCase):
"""
A base TestCase that reduces boilerplate for HomeServer-using test cases.
+ Defines a setUp method which creates a mock reactor, and instantiates a homeserver
+ running on that reactor.
+
+ There are various hooks for modifying the way that the homeserver is instantiated:
+
+ * override make_homeserver, for example by making it pass different parameters into
+ setup_test_homeserver.
+
+ * override default_config, to return a modified configuration dictionary for use
+ by setup_test_homeserver.
+
+ * On a per-test basis, you can use the @override_config decorator to give a
+ dictionary containing additional configuration settings to be added to the basic
+ config dict.
+
Attributes:
servlets (list[function]): List of servlet registration function.
user_id (str): The user ID to assume if auth is hijacked.
@@ -162,6 +177,13 @@ class HomeserverTestCase(TestCase):
hijack_auth = True
needs_threadpool = False
+ def __init__(self, methodName, *args, **kwargs):
+ super().__init__(methodName, *args, **kwargs)
+
+ # see if we have any additional config for this test
+ method = getattr(self, methodName)
+ self._extra_config = getattr(method, "_extra_config", None)
+
def setUp(self):
"""
Set up the TestCase by calling the homeserver constructor, optionally
@@ -270,7 +292,14 @@ class HomeserverTestCase(TestCase):
Args:
name (str): The homeserver name/domain.
"""
- return default_config(name)
+ config = default_config(name)
+
+ # apply any additional config which was specified via the override_config
+ # decorator.
+ if self._extra_config is not None:
+ config.update(self._extra_config)
+
+ return config
def prepare(self, reactor, clock, homeserver):
"""
@@ -412,6 +441,7 @@ class HomeserverTestCase(TestCase):
# Create the user
request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
self.render(request)
+ self.assertEqual(channel.code, 200)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -437,7 +467,7 @@ class HomeserverTestCase(TestCase):
"POST", "/_matrix/client/r0/admin/register", body.encode("utf8")
)
self.render(request)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, 200, channel.json_body)
user_id = channel.json_body["user_id"]
return user_id
@@ -528,3 +558,27 @@ class HomeserverTestCase(TestCase):
)
self.render(request)
self.assertEqual(channel.code, 403, channel.result)
+
+
+def override_config(extra_config):
+ """A decorator which can be applied to test functions to give additional HS config
+
+ For use
+
+ For example:
+
+ class MyTestCase(HomeserverTestCase):
+ @override_config({"enable_registration": False, ...})
+ def test_foo(self):
+ ...
+
+ Args:
+ extra_config(dict): Additional config settings to be merged into the default
+ config dict before instantiating the test homeserver.
+ """
+
+ def decorator(func):
+ func._extra_config = extra_config
+ return func
+
+ return decorator
|