summary refs log tree commit diff
path: root/tests/config
diff options
context:
space:
mode:
Diffstat (limited to 'tests/config')
-rw-r--r--tests/config/test_cache.py171
-rw-r--r--tests/config/test_database.py22
-rw-r--r--tests/config/test_generate.py2
-rw-r--r--tests/config/test_load.py2
-rw-r--r--tests/config/test_tls.py54
5 files changed, 210 insertions, 41 deletions
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
new file mode 100644
index 0000000000..d3ec24c975
--- /dev/null
+++ b/tests/config/test_cache.py
@@ -0,0 +1,171 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config._base import Config, RootConfig
+from synapse.config.cache import CacheConfig, add_resizable_cache
+from synapse.util.caches.lrucache import LruCache
+
+from tests.unittest import TestCase
+
+
+class FakeServer(Config):
+    section = "server"
+
+
+class TestConfig(RootConfig):
+    config_classes = [FakeServer, CacheConfig]
+
+
+class CacheConfigTests(TestCase):
+    def setUp(self):
+        # Reset caches before each test
+        TestConfig().caches.reset()
+
+    def test_individual_caches_from_environ(self):
+        """
+        Individual cache factors will be loaded from the environment.
+        """
+        config = {}
+        t = TestConfig()
+        t.caches._environ = {
+            "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
+            "SYNAPSE_NOT_CACHE": "BLAH",
+        }
+        t.read_config(config, config_dir_path="", data_dir_path="")
+
+        self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0})
+
+    def test_config_overrides_environ(self):
+        """
+        Individual cache factors defined in the environment will take precedence
+        over those in the config.
+        """
+        config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}}
+        t = TestConfig()
+        t.caches._environ = {
+            "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
+            "SYNAPSE_CACHE_FACTOR_FOO": 1,
+        }
+        t.read_config(config, config_dir_path="", data_dir_path="")
+
+        self.assertEqual(
+            dict(t.caches.cache_factors),
+            {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0},
+        )
+
+    def test_individual_instantiated_before_config_load(self):
+        """
+        If a cache is instantiated before the config is read, it will be given
+        the default cache size in the interim, and then resized once the config
+        is loaded.
+        """
+        cache = LruCache(100)
+
+        add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+        self.assertEqual(cache.max_size, 50)
+
+        config = {"caches": {"per_cache_factors": {"foo": 3}}}
+        t = TestConfig()
+        t.read_config(config, config_dir_path="", data_dir_path="")
+
+        self.assertEqual(cache.max_size, 300)
+
+    def test_individual_instantiated_after_config_load(self):
+        """
+        If a cache is instantiated after the config is read, it will be
+        immediately resized to the correct size given the per_cache_factor if
+        there is one.
+        """
+        config = {"caches": {"per_cache_factors": {"foo": 2}}}
+        t = TestConfig()
+        t.read_config(config, config_dir_path="", data_dir_path="")
+
+        cache = LruCache(100)
+        add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+        self.assertEqual(cache.max_size, 200)
+
+    def test_global_instantiated_before_config_load(self):
+        """
+        If a cache is instantiated before the config is read, it will be given
+        the default cache size in the interim, and then resized to the new
+        default cache size once the config is loaded.
+        """
+        cache = LruCache(100)
+        add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+        self.assertEqual(cache.max_size, 50)
+
+        config = {"caches": {"global_factor": 4}}
+        t = TestConfig()
+        t.read_config(config, config_dir_path="", data_dir_path="")
+
+        self.assertEqual(cache.max_size, 400)
+
+    def test_global_instantiated_after_config_load(self):
+        """
+        If a cache is instantiated after the config is read, it will be
+        immediately resized to the correct size given the global factor if there
+        is no per-cache factor.
+        """
+        config = {"caches": {"global_factor": 1.5}}
+        t = TestConfig()
+        t.read_config(config, config_dir_path="", data_dir_path="")
+
+        cache = LruCache(100)
+        add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+        self.assertEqual(cache.max_size, 150)
+
+    def test_cache_with_asterisk_in_name(self):
+        """Some caches have asterisks in their name, test that they are set correctly.
+        """
+
+        config = {
+            "caches": {
+                "per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2}
+            }
+        }
+        t = TestConfig()
+        t.caches._environ = {
+            "SYNAPSE_CACHE_FACTOR_CACHE_A": "2",
+            "SYNAPSE_CACHE_FACTOR_CACHE_B": 3,
+        }
+        t.read_config(config, config_dir_path="", data_dir_path="")
+
+        cache_a = LruCache(100)
+        add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
+        self.assertEqual(cache_a.max_size, 200)
+
+        cache_b = LruCache(100)
+        add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor)
+        self.assertEqual(cache_b.max_size, 300)
+
+        cache_c = LruCache(100)
+        add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
+        self.assertEqual(cache_c.max_size, 200)
+
+    def test_apply_cache_factor_from_config(self):
+        """Caches can disable applying cache factor updates, mainly used by
+        event cache size.
+        """
+
+        config = {"caches": {"event_cache_size": "10k"}}
+        t = TestConfig()
+        t.read_config(config, config_dir_path="", data_dir_path="")
+
+        cache = LruCache(
+            max_size=t.caches.event_cache_size, apply_cache_factor_from_config=False,
+        )
+        add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor)
+
+        self.assertEqual(cache.max_size, 10240)
diff --git a/tests/config/test_database.py b/tests/config/test_database.py
index 151d3006ac..f675bde68e 100644
--- a/tests/config/test_database.py
+++ b/tests/config/test_database.py
@@ -21,9 +21,9 @@ from tests import unittest
 
 
 class DatabaseConfigTestCase(unittest.TestCase):
-    def test_database_configured_correctly_no_database_conf_param(self):
+    def test_database_configured_correctly(self):
         conf = yaml.safe_load(
-            DatabaseConfig().generate_config_section("/data_dir_path", None)
+            DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path")
         )
 
         expected_database_conf = {
@@ -32,21 +32,3 @@ class DatabaseConfigTestCase(unittest.TestCase):
         }
 
         self.assertEqual(conf["database"], expected_database_conf)
-
-    def test_database_configured_correctly_database_conf_param(self):
-
-        database_conf = {
-            "name": "my super fast datastore",
-            "args": {
-                "user": "matrix",
-                "password": "synapse_database_password",
-                "host": "synapse_database_host",
-                "database": "matrix",
-            },
-        }
-
-        conf = yaml.safe_load(
-            DatabaseConfig().generate_config_section("/data_dir_path", database_conf)
-        )
-
-        self.assertEqual(conf["database"], database_conf)
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index 2684e662de..463855ecc8 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -48,7 +48,7 @@ class ConfigGenerationTestCase(unittest.TestCase):
             )
 
         self.assertSetEqual(
-            set(["homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"]),
+            {"homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"},
             set(os.listdir(self.dir)),
         )
 
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index b3e557bd6a..734a9983e8 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -122,7 +122,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
 
         with open(self.file, "r") as f:
             contents = f.readlines()
-        contents = [l for l in contents if needle not in l]
+        contents = [line for line in contents if needle not in line]
         with open(self.file, "w") as f:
             f.write("".join(contents))
 
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index b02780772a..ec32d4b1ca 100644
--- a/tests/config/test_tls.py
+++ b/tests/config/test_tls.py
@@ -21,17 +21,24 @@ import yaml
 
 from OpenSSL import SSL
 
+from synapse.config._base import Config, RootConfig
 from synapse.config.tls import ConfigError, TlsConfig
-from synapse.crypto.context_factory import ClientTLSOptionsFactory
+from synapse.crypto.context_factory import FederationPolicyForHTTPS
 
 from tests.unittest import TestCase
 
 
-class TestConfig(TlsConfig):
+class FakeServer(Config):
+    section = "server"
+
     def has_tls_listener(self):
         return False
 
 
+class TestConfig(RootConfig):
+    config_classes = [FakeServer, TlsConfig]
+
+
 class TLSConfigTests(TestCase):
     def test_warn_self_signed(self):
         """
@@ -173,12 +180,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
         t = TestConfig()
         t.read_config(config, config_dir_path="", data_dir_path="")
 
-        cf = ClientTLSOptionsFactory(t)
+        cf = FederationPolicyForHTTPS(t)
+        options = _get_ssl_context_options(cf._verify_ssl_context)
 
         # The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2
-        self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0)
-        self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0)
-        self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0)
+        self.assertNotEqual(options & SSL.OP_NO_TLSv1, 0)
+        self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0)
+        self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
 
     def test_tls_client_minimum_set_passed_through_1_0(self):
         """
@@ -188,12 +196,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
         t = TestConfig()
         t.read_config(config, config_dir_path="", data_dir_path="")
 
-        cf = ClientTLSOptionsFactory(t)
+        cf = FederationPolicyForHTTPS(t)
+        options = _get_ssl_context_options(cf._verify_ssl_context)
 
         # The context has not had any of the NO_TLS set.
-        self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0)
-        self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0)
-        self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0)
+        self.assertEqual(options & SSL.OP_NO_TLSv1, 0)
+        self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0)
+        self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
 
     def test_acme_disabled_in_generated_config_no_acme_domain_provied(self):
         """
@@ -202,13 +211,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
         conf = TestConfig()
         conf.read_config(
             yaml.safe_load(
-                TestConfig().generate_config_section(
+                TestConfig().generate_config(
                     "/config_dir_path",
                     "my_super_secure_server",
                     "/data_dir_path",
-                    "/tls_cert_path",
-                    "tls_private_key",
-                    None,  # This is the acme_domain
+                    tls_certificate_path="/tls_cert_path",
+                    tls_private_key_path="tls_private_key",
+                    acme_domain=None,  # This is the acme_domain
                 )
             ),
             "/config_dir_path",
@@ -223,13 +232,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
         conf = TestConfig()
         conf.read_config(
             yaml.safe_load(
-                TestConfig().generate_config_section(
+                TestConfig().generate_config(
                     "/config_dir_path",
                     "my_super_secure_server",
                     "/data_dir_path",
-                    "/tls_cert_path",
-                    "tls_private_key",
-                    "my_supe_secure_server",  # This is the acme_domain
+                    tls_certificate_path="/tls_cert_path",
+                    tls_private_key_path="tls_private_key",
+                    acme_domain="my_supe_secure_server",  # This is the acme_domain
                 )
             ),
             "/config_dir_path",
@@ -266,7 +275,7 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
         t = TestConfig()
         t.read_config(config, config_dir_path="", data_dir_path="")
 
-        cf = ClientTLSOptionsFactory(t)
+        cf = FederationPolicyForHTTPS(t)
 
         # Not in the whitelist
         opts = cf.get_options(b"notexample.com")
@@ -275,3 +284,10 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
         # Caught by the wildcard
         opts = cf.get_options(idna.encode("テスト.ドメイン.テスト"))
         self.assertFalse(opts._verifier._verify_certs)
+
+
+def _get_ssl_context_options(ssl_context: SSL.Context) -> int:
+    """get the options bits from an openssl context object"""
+    # the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to
+    # use the low-level interface
+    return SSL._lib.SSL_CTX_get_options(ssl_context._context)