summary refs log tree commit diff
path: root/tests/config
diff options
context:
space:
mode:
authorOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2021-10-12 10:47:13 +0100
committerOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2021-10-12 10:47:13 +0100
commitaf85ac449d847fc4382d4381e3dae67c047b25b7 (patch)
tree31745dd8d49ed6f44d31e5eceaf0a45c92682164 /tests/config
parentRevert "Add a stub implementation of `StateFilter.approx_difference`" (diff)
parentAdd an approximate difference method to StateFilters (#10825) (diff)
downloadsynapse-af85ac449d847fc4382d4381e3dae67c047b25b7.tar.xz
Merge remote-tracking branch 'origin/develop' into rei/gsfg_1
to introduce `approx_difference`.
Diffstat (limited to 'tests/config')
-rw-r--r--tests/config/test_base.py21
-rw-r--r--tests/config/test_cache.py70
-rw-r--r--tests/config/test_load.py28
-rw-r--r--tests/config/test_ratelimiting.py2
-rw-r--r--tests/config/test_tls.py38
5 files changed, 79 insertions, 80 deletions
diff --git a/tests/config/test_base.py b/tests/config/test_base.py
index baa5313fb3..6a52f862f4 100644
--- a/tests/config/test_base.py
+++ b/tests/config/test_base.py
@@ -14,23 +14,28 @@
 
 import os.path
 import tempfile
+from unittest.mock import Mock
 
 from synapse.config import ConfigError
+from synapse.config._base import Config
 from synapse.util.stringutils import random_string
 
 from tests import unittest
 
 
-class BaseConfigTestCase(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, hs):
-        self.hs = hs
+class BaseConfigTestCase(unittest.TestCase):
+    def setUp(self):
+        # The root object needs a server property with a public_baseurl.
+        root = Mock()
+        root.server.public_baseurl = "http://test"
+        self.config = Config(root)
 
     def test_loading_missing_templates(self):
         # Use a temporary directory that exists on the system, but that isn't likely to
         # contain template files
         with tempfile.TemporaryDirectory() as tmp_dir:
             # Attempt to load an HTML template from our custom template directory
-            template = self.hs.config.read_templates(["sso_error.html"], (tmp_dir,))[0]
+            template = self.config.read_templates(["sso_error.html"], (tmp_dir,))[0]
 
         # If no errors, we should've gotten the default template instead
 
@@ -60,7 +65,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase):
 
                 # Attempt to load the template from our custom template directory
                 template = (
-                    self.hs.config.read_templates([template_filename], (tmp_dir,))
+                    self.config.read_templates([template_filename], (tmp_dir,))
                 )[0]
 
         # Render the template
@@ -97,7 +102,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase):
 
         # Retrieve the template.
         template = (
-            self.hs.config.read_templates(
+            self.config.read_templates(
                 [template_filename],
                 (td.name for td in tempdirs),
             )
@@ -118,7 +123,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase):
 
         # Retrieve the template.
         template = (
-            self.hs.config.read_templates(
+            self.config.read_templates(
                 [other_template_name],
                 (td.name for td in tempdirs),
             )
@@ -134,6 +139,6 @@ class BaseConfigTestCase(unittest.HomeserverTestCase):
 
     def test_loading_template_from_nonexistent_custom_directory(self):
         with self.assertRaises(ConfigError):
-            self.hs.config.read_templates(
+            self.config.read_templates(
                 ["some_filename.html"], ("a_nonexistent_directory",)
             )
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index 857d9cd096..79d417568d 100644
--- a/tests/config/test_cache.py
+++ b/tests/config/test_cache.py
@@ -12,59 +12,55 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.config._base import Config, RootConfig
+from unittest.mock import patch
+
 from synapse.config.cache import CacheConfig, add_resizable_cache
 from synapse.util.caches.lrucache import LruCache
 
 from tests.unittest import TestCase
 
 
-class FakeServer(Config):
-    section = "server"
-
-
-class TestConfig(RootConfig):
-    config_classes = [FakeServer, CacheConfig]
-
-
+# Patch the global _CACHES so that each test runs against its own state.
+@patch("synapse.config.cache._CACHES", new_callable=dict)
 class CacheConfigTests(TestCase):
     def setUp(self):
         # Reset caches before each test
-        TestConfig().caches.reset()
+        self.config = CacheConfig()
+
+    def tearDown(self):
+        self.config.reset()
 
-    def test_individual_caches_from_environ(self):
+    def test_individual_caches_from_environ(self, _caches):
         """
         Individual cache factors will be loaded from the environment.
         """
         config = {}
-        t = TestConfig()
-        t.caches._environ = {
+        self.config._environ = {
             "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
             "SYNAPSE_NOT_CACHE": "BLAH",
         }
-        t.read_config(config, config_dir_path="", data_dir_path="")
+        self.config.read_config(config, config_dir_path="", data_dir_path="")
 
-        self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0})
+        self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0})
 
-    def test_config_overrides_environ(self):
+    def test_config_overrides_environ(self, _caches):
         """
         Individual cache factors defined in the environment will take precedence
         over those in the config.
         """
         config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}}
-        t = TestConfig()
-        t.caches._environ = {
+        self.config._environ = {
             "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
             "SYNAPSE_CACHE_FACTOR_FOO": 1,
         }
-        t.read_config(config, config_dir_path="", data_dir_path="")
+        self.config.read_config(config, config_dir_path="", data_dir_path="")
 
         self.assertEqual(
-            dict(t.caches.cache_factors),
+            dict(self.config.cache_factors),
             {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0},
         )
 
-    def test_individual_instantiated_before_config_load(self):
+    def test_individual_instantiated_before_config_load(self, _caches):
         """
         If a cache is instantiated before the config is read, it will be given
         the default cache size in the interim, and then resized once the config
@@ -76,26 +72,24 @@ class CacheConfigTests(TestCase):
         self.assertEqual(cache.max_size, 50)
 
         config = {"caches": {"per_cache_factors": {"foo": 3}}}
-        t = TestConfig()
-        t.read_config(config, config_dir_path="", data_dir_path="")
+        self.config.read_config(config)
 
         self.assertEqual(cache.max_size, 300)
 
-    def test_individual_instantiated_after_config_load(self):
+    def test_individual_instantiated_after_config_load(self, _caches):
         """
         If a cache is instantiated after the config is read, it will be
         immediately resized to the correct size given the per_cache_factor if
         there is one.
         """
         config = {"caches": {"per_cache_factors": {"foo": 2}}}
-        t = TestConfig()
-        t.read_config(config, config_dir_path="", data_dir_path="")
+        self.config.read_config(config, config_dir_path="", data_dir_path="")
 
         cache = LruCache(100)
         add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
         self.assertEqual(cache.max_size, 200)
 
-    def test_global_instantiated_before_config_load(self):
+    def test_global_instantiated_before_config_load(self, _caches):
         """
         If a cache is instantiated before the config is read, it will be given
         the default cache size in the interim, and then resized to the new
@@ -106,26 +100,24 @@ class CacheConfigTests(TestCase):
         self.assertEqual(cache.max_size, 50)
 
         config = {"caches": {"global_factor": 4}}
-        t = TestConfig()
-        t.read_config(config, config_dir_path="", data_dir_path="")
+        self.config.read_config(config, config_dir_path="", data_dir_path="")
 
         self.assertEqual(cache.max_size, 400)
 
-    def test_global_instantiated_after_config_load(self):
+    def test_global_instantiated_after_config_load(self, _caches):
         """
         If a cache is instantiated after the config is read, it will be
         immediately resized to the correct size given the global factor if there
         is no per-cache factor.
         """
         config = {"caches": {"global_factor": 1.5}}
-        t = TestConfig()
-        t.read_config(config, config_dir_path="", data_dir_path="")
+        self.config.read_config(config, config_dir_path="", data_dir_path="")
 
         cache = LruCache(100)
         add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
         self.assertEqual(cache.max_size, 150)
 
-    def test_cache_with_asterisk_in_name(self):
+    def test_cache_with_asterisk_in_name(self, _caches):
         """Some caches have asterisks in their name, test that they are set correctly."""
 
         config = {
@@ -133,12 +125,11 @@ class CacheConfigTests(TestCase):
                 "per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2}
             }
         }
-        t = TestConfig()
-        t.caches._environ = {
+        self.config._environ = {
             "SYNAPSE_CACHE_FACTOR_CACHE_A": "2",
             "SYNAPSE_CACHE_FACTOR_CACHE_B": 3,
         }
-        t.read_config(config, config_dir_path="", data_dir_path="")
+        self.config.read_config(config, config_dir_path="", data_dir_path="")
 
         cache_a = LruCache(100)
         add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
@@ -152,17 +143,16 @@ class CacheConfigTests(TestCase):
         add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
         self.assertEqual(cache_c.max_size, 200)
 
-    def test_apply_cache_factor_from_config(self):
+    def test_apply_cache_factor_from_config(self, _caches):
         """Caches can disable applying cache factor updates, mainly used by
         event cache size.
         """
 
         config = {"caches": {"event_cache_size": "10k"}}
-        t = TestConfig()
-        t.read_config(config, config_dir_path="", data_dir_path="")
+        self.config.read_config(config, config_dir_path="", data_dir_path="")
 
         cache = LruCache(
-            max_size=t.caches.event_cache_size,
+            max_size=self.config.event_cache_size,
             apply_cache_factor_from_config=False,
         )
         add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor)
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 903c69127d..59635de205 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -49,24 +49,24 @@ class ConfigLoadingTestCase(unittest.TestCase):
 
         config = HomeServerConfig.load_config("", ["-c", self.file])
         self.assertTrue(
-            hasattr(config, "macaroon_secret_key"),
+            hasattr(config.key, "macaroon_secret_key"),
             "Want config to have attr macaroon_secret_key",
         )
-        if len(config.macaroon_secret_key) < 5:
+        if len(config.key.macaroon_secret_key) < 5:
             self.fail(
                 "Want macaroon secret key to be string of at least length 5,"
-                "was: %r" % (config.macaroon_secret_key,)
+                "was: %r" % (config.key.macaroon_secret_key,)
             )
 
         config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
         self.assertTrue(
-            hasattr(config, "macaroon_secret_key"),
+            hasattr(config.key, "macaroon_secret_key"),
             "Want config to have attr macaroon_secret_key",
         )
-        if len(config.macaroon_secret_key) < 5:
+        if len(config.key.macaroon_secret_key) < 5:
             self.fail(
                 "Want macaroon secret key to be string of at least length 5,"
-                "was: %r" % (config.macaroon_secret_key,)
+                "was: %r" % (config.key.macaroon_secret_key,)
             )
 
     def test_load_succeeds_if_macaroon_secret_key_missing(self):
@@ -74,8 +74,12 @@ class ConfigLoadingTestCase(unittest.TestCase):
         config1 = HomeServerConfig.load_config("", ["-c", self.file])
         config2 = HomeServerConfig.load_config("", ["-c", self.file])
         config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
-        self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key)
-        self.assertEqual(config1.macaroon_secret_key, config3.macaroon_secret_key)
+        self.assertEqual(
+            config1.key.macaroon_secret_key, config2.key.macaroon_secret_key
+        )
+        self.assertEqual(
+            config1.key.macaroon_secret_key, config3.key.macaroon_secret_key
+        )
 
     def test_disable_registration(self):
         self.generate_config()
@@ -84,16 +88,16 @@ class ConfigLoadingTestCase(unittest.TestCase):
         )
         # Check that disable_registration clobbers enable_registration.
         config = HomeServerConfig.load_config("", ["-c", self.file])
-        self.assertFalse(config.enable_registration)
+        self.assertFalse(config.registration.enable_registration)
 
         config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
-        self.assertFalse(config.enable_registration)
+        self.assertFalse(config.registration.enable_registration)
 
         # Check that either config value is clobbered by the command line.
         config = HomeServerConfig.load_or_generate_config(
             "", ["-c", self.file, "--enable-registration"]
         )
-        self.assertTrue(config.enable_registration)
+        self.assertTrue(config.registration.enable_registration)
 
     def test_stats_enabled(self):
         self.generate_config_and_remove_lines_containing("enable_metrics")
@@ -101,7 +105,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
 
         # The default Metrics Flags are off by default.
         config = HomeServerConfig.load_config("", ["-c", self.file])
-        self.assertFalse(config.metrics_flags.known_servers)
+        self.assertFalse(config.metrics.metrics_flags.known_servers)
 
     def generate_config(self):
         with redirect_stdout(StringIO()):
diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py
index 3c7bb32e07..1b63e1adfd 100644
--- a/tests/config/test_ratelimiting.py
+++ b/tests/config/test_ratelimiting.py
@@ -30,7 +30,7 @@ class RatelimitConfigTestCase(TestCase):
 
         config = HomeServerConfig()
         config.parse_config_dict(config_dict, "", "")
-        config_obj = config.rc_federation
+        config_obj = config.ratelimiting.rc_federation
 
         self.assertEqual(config_obj.window_size, 20000)
         self.assertEqual(config_obj.sleep_limit, 693)
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index b6bc1876b5..9ba5781573 100644
--- a/tests/config/test_tls.py
+++ b/tests/config/test_tls.py
@@ -42,9 +42,9 @@ class TLSConfigTests(TestCase):
         """
         config = {}
         t = TestConfig()
-        t.read_config(config, config_dir_path="", data_dir_path="")
+        t.tls.read_config(config, config_dir_path="", data_dir_path="")
 
-        self.assertEqual(t.federation_client_minimum_tls_version, "1")
+        self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
 
     def test_tls_client_minimum_set(self):
         """
@@ -52,29 +52,29 @@ class TLSConfigTests(TestCase):
         """
         config = {"federation_client_minimum_tls_version": 1}
         t = TestConfig()
-        t.read_config(config, config_dir_path="", data_dir_path="")
-        self.assertEqual(t.federation_client_minimum_tls_version, "1")
+        t.tls.read_config(config, config_dir_path="", data_dir_path="")
+        self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
 
         config = {"federation_client_minimum_tls_version": 1.1}
         t = TestConfig()
-        t.read_config(config, config_dir_path="", data_dir_path="")
-        self.assertEqual(t.federation_client_minimum_tls_version, "1.1")
+        t.tls.read_config(config, config_dir_path="", data_dir_path="")
+        self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.1")
 
         config = {"federation_client_minimum_tls_version": 1.2}
         t = TestConfig()
-        t.read_config(config, config_dir_path="", data_dir_path="")
-        self.assertEqual(t.federation_client_minimum_tls_version, "1.2")
+        t.tls.read_config(config, config_dir_path="", data_dir_path="")
+        self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2")
 
         # Also test a string version
         config = {"federation_client_minimum_tls_version": "1"}
         t = TestConfig()
-        t.read_config(config, config_dir_path="", data_dir_path="")
-        self.assertEqual(t.federation_client_minimum_tls_version, "1")
+        t.tls.read_config(config, config_dir_path="", data_dir_path="")
+        self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
 
         config = {"federation_client_minimum_tls_version": "1.2"}
         t = TestConfig()
-        t.read_config(config, config_dir_path="", data_dir_path="")
-        self.assertEqual(t.federation_client_minimum_tls_version, "1.2")
+        t.tls.read_config(config, config_dir_path="", data_dir_path="")
+        self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2")
 
     def test_tls_client_minimum_1_point_3_missing(self):
         """
@@ -91,7 +91,7 @@ class TLSConfigTests(TestCase):
         config = {"federation_client_minimum_tls_version": 1.3}
         t = TestConfig()
         with self.assertRaises(ConfigError) as e:
-            t.read_config(config, config_dir_path="", data_dir_path="")
+            t.tls.read_config(config, config_dir_path="", data_dir_path="")
         self.assertEqual(
             e.exception.args[0],
             (
@@ -112,8 +112,8 @@ class TLSConfigTests(TestCase):
 
         config = {"federation_client_minimum_tls_version": 1.3}
         t = TestConfig()
-        t.read_config(config, config_dir_path="", data_dir_path="")
-        self.assertEqual(t.federation_client_minimum_tls_version, "1.3")
+        t.tls.read_config(config, config_dir_path="", data_dir_path="")
+        self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3")
 
     def test_tls_client_minimum_set_passed_through_1_2(self):
         """
@@ -121,7 +121,7 @@ class TLSConfigTests(TestCase):
         """
         config = {"federation_client_minimum_tls_version": 1.2}
         t = TestConfig()
-        t.read_config(config, config_dir_path="", data_dir_path="")
+        t.tls.read_config(config, config_dir_path="", data_dir_path="")
 
         cf = FederationPolicyForHTTPS(t)
         options = _get_ssl_context_options(cf._verify_ssl_context)
@@ -137,7 +137,7 @@ class TLSConfigTests(TestCase):
         """
         config = {"federation_client_minimum_tls_version": 1}
         t = TestConfig()
-        t.read_config(config, config_dir_path="", data_dir_path="")
+        t.tls.read_config(config, config_dir_path="", data_dir_path="")
 
         cf = FederationPolicyForHTTPS(t)
         options = _get_ssl_context_options(cf._verify_ssl_context)
@@ -159,7 +159,7 @@ class TLSConfigTests(TestCase):
         }
         t = TestConfig()
         e = self.assertRaises(
-            ConfigError, t.read_config, config, config_dir_path="", data_dir_path=""
+            ConfigError, t.tls.read_config, config, config_dir_path="", data_dir_path=""
         )
         self.assertIn("IDNA domain names", str(e))
 
@@ -174,7 +174,7 @@ class TLSConfigTests(TestCase):
             ]
         }
         t = TestConfig()
-        t.read_config(config, config_dir_path="", data_dir_path="")
+        t.tls.read_config(config, config_dir_path="", data_dir_path="")
 
         cf = FederationPolicyForHTTPS(t)