summary refs log tree commit diff
path: root/synapse/config/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/config/_base.py')
-rw-r--r--synapse/config/_base.py239
1 files changed, 176 insertions, 63 deletions
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 31f6530978..30d1050a91 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,7 +18,9 @@
 import argparse
 import errno
 import os
+from collections import OrderedDict
 from textwrap import dedent
+from typing import Any, MutableMapping, Optional
 
 from six import integer_types
 
@@ -51,7 +53,68 @@ Missing mandatory `server_name` config option.
 """
 
 
+CONFIG_FILE_HEADER = """\
+# Configuration file for Synapse.
+#
+# This is a YAML file: see [1] for a quick introduction. Note in particular
+# that *indentation is important*: all the elements of a list or dictionary
+# should have the same indentation.
+#
+# [1] https://docs.ansible.com/ansible/latest/reference_appendices/YAMLSyntax.html
+
+"""
+
+
+def path_exists(file_path):
+    """Check if a file exists
+
+    Unlike os.path.exists, this throws an exception if there is an error
+    checking if the file exists (for example, if there is a perms error on
+    the parent dir).
+
+    Returns:
+        bool: True if the file exists; False if not.
+    """
+    try:
+        os.stat(file_path)
+        return True
+    except OSError as e:
+        if e.errno != errno.ENOENT:
+            raise e
+        return False
+
+
 class Config(object):
+    """
+    A configuration section, containing configuration keys and values.
+
+    Attributes:
+        section (str): The section title of this config object, such as
+            "tls" or "logger". This is used to refer to it on the root
+            logger (for example, `config.tls.some_option`). Must be
+            defined in subclasses.
+    """
+
+    section = None
+
+    def __init__(self, root_config=None):
+        self.root = root_config
+
+    def __getattr__(self, item: str) -> Any:
+        """
+        Try and fetch a configuration option that does not exist on this class.
+
+        This is so that existing configs that rely on `self.value`, where value
+        is actually from a different config section, continue to work.
+        """
+        if item in ["generate_config_section", "read_config"]:
+            raise AttributeError(item)
+
+        if self.root is None:
+            raise AttributeError(item)
+        else:
+            return self.root._get_unclassed_config(self.section, item)
+
     @staticmethod
     def parse_size(value):
         if isinstance(value, integer_types):
@@ -88,22 +151,7 @@ class Config(object):
 
     @classmethod
     def path_exists(cls, file_path):
-        """Check if a file exists
-
-        Unlike os.path.exists, this throws an exception if there is an error
-        checking if the file exists (for example, if there is a perms error on
-        the parent dir).
-
-        Returns:
-            bool: True if the file exists; False if not.
-        """
-        try:
-            os.stat(file_path)
-            return True
-        except OSError as e:
-            if e.errno != errno.ENOENT:
-                raise e
-            return False
+        return path_exists(file_path)
 
     @classmethod
     def check_file(cls, file_path, config_name):
@@ -136,42 +184,106 @@ class Config(object):
         with open(file_path) as file_stream:
             return file_stream.read()
 
-    def invoke_all(self, name, *args, **kargs):
-        """Invoke all instance methods with the given name and arguments in the
-        class's MRO.
+
+class RootConfig(object):
+    """
+    Holder of an application's configuration.
+
+    What configuration this object holds is defined by `config_classes`, a list
+    of Config classes that will be instantiated and given the contents of a
+    configuration file to read. They can then be accessed on this class by their
+    section name, defined in the Config or dynamically set to be the name of the
+    class, lower-cased and with "Config" removed.
+    """
+
+    config_classes = []
+
+    def __init__(self):
+        self._configs = OrderedDict()
+
+        for config_class in self.config_classes:
+            if config_class.section is None:
+                raise ValueError("%r requires a section name" % (config_class,))
+
+            try:
+                conf = config_class(self)
+            except Exception as e:
+                raise Exception("Failed making %s: %r" % (config_class.section, e))
+            self._configs[config_class.section] = conf
+
+    def __getattr__(self, item: str) -> Any:
+        """
+        Redirect lookups on this object either to config objects, or values on
+        config objects, so that `config.tls.blah` works, as well as legacy uses
+        of things like `config.server_name`. It will first look up the config
+        section name, and then values on those config classes.
+        """
+        if item in self._configs.keys():
+            return self._configs[item]
+
+        return self._get_unclassed_config(None, item)
+
+    def _get_unclassed_config(self, asking_section: Optional[str], item: str):
+        """
+        Fetch a config value from one of the instantiated config classes that
+        has not been fetched directly.
 
         Args:
-            name (str): Name of function to invoke
+            asking_section: If this check is coming from a Config child, which
+                one? This section will not be asked if it has the value.
+            item: The configuration value key.
+
+        Raises:
+            AttributeError if no config classes have the config key. The body
+                will contain what sections were checked.
+        """
+        for key, val in self._configs.items():
+            if key == asking_section:
+                continue
+
+            if item in dir(val):
+                return getattr(val, item)
+
+        raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),))
+
+    def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
+        """
+        Invoke a function on all instantiated config objects this RootConfig is
+        configured to use.
+
+        Args:
+            func_name: Name of function to invoke
             *args
             **kwargs
-
         Returns:
-            list: The list of the return values from each method called
+            ordered dictionary of config section name and the result of the
+            function from it.
         """
-        results = []
-        for cls in type(self).mro():
-            if name in cls.__dict__:
-                results.append(getattr(cls, name)(self, *args, **kargs))
-        return results
+        res = OrderedDict()
+
+        for name, config in self._configs.items():
+            if hasattr(config, func_name):
+                res[name] = getattr(config, func_name)(*args, **kwargs)
+
+        return res
 
     @classmethod
-    def invoke_all_static(cls, name, *args, **kargs):
-        """Invoke all static methods with the given name and arguments in the
-        class's MRO.
+    def invoke_all_static(cls, func_name: str, *args, **kwargs):
+        """
+        Invoke a static function on config objects this RootConfig is
+        configured to use.
 
         Args:
-            name (str): Name of function to invoke
+            func_name: Name of function to invoke
             *args
             **kwargs
-
         Returns:
-            list: The list of the return values from each method called
+            ordered dictionary of config section name and the result of the
+            function from it.
         """
-        results = []
-        for c in cls.mro():
-            if name in c.__dict__:
-                results.append(getattr(c, name)(*args, **kargs))
-        return results
+        for config in cls.config_classes:
+            if hasattr(config, func_name):
+                getattr(config, func_name)(*args, **kwargs)
 
     def generate_config(
         self,
@@ -182,12 +294,12 @@ class Config(object):
         report_stats=None,
         open_private_ports=False,
         listeners=None,
-        database_conf=None,
         tls_certificate_path=None,
         tls_private_key_path=None,
         acme_domain=None,
     ):
-        """Build a default configuration file
+        """
+        Build a default configuration file
 
         This is used when the user explicitly asks us to generate a config file
         (eg with --generate_config).
@@ -242,7 +354,8 @@ class Config(object):
         Returns:
             str: the yaml config file
         """
-        return "\n\n".join(
+
+        return CONFIG_FILE_HEADER + "\n\n".join(
             dedent(conf)
             for conf in self.invoke_all(
                 "generate_config_section",
@@ -253,11 +366,10 @@ class Config(object):
                 report_stats=report_stats,
                 open_private_ports=open_private_ports,
                 listeners=listeners,
-                database_conf=database_conf,
                 tls_certificate_path=tls_certificate_path,
                 tls_private_key_path=tls_private_key_path,
                 acme_domain=acme_domain,
-            )
+            ).values()
         )
 
     @classmethod
@@ -356,8 +468,8 @@ class Config(object):
 
         Returns: Config object, or None if --generate-config or --generate-keys was set
         """
-        config_parser = argparse.ArgumentParser(add_help=False)
-        config_parser.add_argument(
+        parser = argparse.ArgumentParser(description=description)
+        parser.add_argument(
             "-c",
             "--config-path",
             action="append",
@@ -366,7 +478,7 @@ class Config(object):
             " may specify directories containing *.yaml files.",
         )
 
-        generate_group = config_parser.add_argument_group("Config generation")
+        generate_group = parser.add_argument_group("Config generation")
         generate_group.add_argument(
             "--generate-config",
             action="store_true",
@@ -414,12 +526,13 @@ class Config(object):
             ),
         )
 
-        config_args, remaining_args = config_parser.parse_known_args(argv)
+        cls.invoke_all_static("add_arguments", parser)
+        config_args = parser.parse_args(argv)
 
         config_files = find_config_files(search_paths=config_args.config_path)
 
         if not config_files:
-            config_parser.error(
+            parser.error(
                 "Must supply a config file.\nA config file can be automatically"
                 ' generated using "--generate-config -H SERVER_NAME'
                 ' -c CONFIG-FILE"'
@@ -438,13 +551,13 @@ class Config(object):
 
         if config_args.generate_config:
             if config_args.report_stats is None:
-                config_parser.error(
+                parser.error(
                     "Please specify either --report-stats=yes or --report-stats=no\n\n"
                     + MISSING_REPORT_STATS_SPIEL
                 )
 
             (config_path,) = config_files
-            if not cls.path_exists(config_path):
+            if not path_exists(config_path):
                 print("Generating config file %s" % (config_path,))
 
                 if config_args.data_directory:
@@ -469,11 +582,11 @@ class Config(object):
                     open_private_ports=config_args.open_private_ports,
                 )
 
-                if not cls.path_exists(config_dir_path):
+                if not path_exists(config_dir_path):
                     os.makedirs(config_dir_path)
                 with open(config_path, "w") as config_file:
-                    config_file.write("# vim:ft=yaml\n\n")
                     config_file.write(config_str)
+                    config_file.write("\n\n# vim:ft=yaml")
 
                 config_dict = yaml.safe_load(config_str)
                 obj.generate_missing_files(config_dict, config_dir_path)
@@ -497,15 +610,6 @@ class Config(object):
                 )
                 generate_missing_configs = True
 
-        parser = argparse.ArgumentParser(
-            parents=[config_parser],
-            description=description,
-            formatter_class=argparse.RawDescriptionHelpFormatter,
-        )
-
-        obj.invoke_all_static("add_arguments", parser)
-        args = parser.parse_args(remaining_args)
-
         config_dict = read_config_files(config_files)
         if generate_missing_configs:
             obj.generate_missing_files(config_dict, config_dir_path)
@@ -514,11 +618,11 @@ class Config(object):
         obj.parse_config_dict(
             config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path
         )
-        obj.invoke_all("read_arguments", args)
+        obj.invoke_all("read_arguments", config_args)
 
         return obj
 
-    def parse_config_dict(self, config_dict, config_dir_path, data_dir_path):
+    def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None):
         """Read the information from the config dict into this Config object.
 
         Args:
@@ -553,6 +657,12 @@ def read_config_files(config_files):
     for config_file in config_files:
         with open(config_file) as file_stream:
             yaml_config = yaml.safe_load(file_stream)
+
+        if not isinstance(yaml_config, dict):
+            err = "File %r is empty or doesn't parse into a key-value map. IGNORING."
+            print(err % (config_file,))
+            continue
+
         specified_config.update(yaml_config)
 
     if "server_name" not in specified_config:
@@ -607,3 +717,6 @@ def find_config_files(search_paths):
             else:
                 config_files.append(config_path)
     return config_files
+
+
+__all__ = ["Config", "RootConfig"]