summary refs log tree commit diff
path: root/synapse/config/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/config/_base.py')
-rw-r--r--synapse/config/_base.py191
1 files changed, 148 insertions, 43 deletions
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 31f6530978..08619404bb 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,7 +18,9 @@
 import argparse
 import errno
 import os
+from collections import OrderedDict
 from textwrap import dedent
+from typing import Any, MutableMapping, Optional
 
 from six import integer_types
 
@@ -51,7 +53,56 @@ Missing mandatory `server_name` config option.
 """
 
 
+def path_exists(file_path):
+    """Check if a file exists
+
+    Unlike os.path.exists, this throws an exception if there is an error
+    checking if the file exists (for example, if there is a perms error on
+    the parent dir).
+
+    Returns:
+        bool: True if the file exists; False if not.
+    """
+    try:
+        os.stat(file_path)
+        return True
+    except OSError as e:
+        if e.errno != errno.ENOENT:
+            raise e
+        return False
+
+
 class Config(object):
+    """
+    A configuration section, containing configuration keys and values.
+
+    Attributes:
+        section (str): The section title of this config object, such as
+            "tls" or "logger". This is used to refer to it on the root
+            logger (for example, `config.tls.some_option`). Must be
+            defined in subclasses.
+    """
+
+    section = None
+
+    def __init__(self, root_config=None):
+        self.root = root_config
+
+    def __getattr__(self, item: str) -> Any:
+        """
+        Try and fetch a configuration option that does not exist on this class.
+
+        This is so that existing configs that rely on `self.value`, where value
+        is actually from a different config section, continue to work.
+        """
+        if item in ["generate_config_section", "read_config"]:
+            raise AttributeError(item)
+
+        if self.root is None:
+            raise AttributeError(item)
+        else:
+            return self.root._get_unclassed_config(self.section, item)
+
     @staticmethod
     def parse_size(value):
         if isinstance(value, integer_types):
@@ -88,22 +139,7 @@ class Config(object):
 
     @classmethod
     def path_exists(cls, file_path):
-        """Check if a file exists
-
-        Unlike os.path.exists, this throws an exception if there is an error
-        checking if the file exists (for example, if there is a perms error on
-        the parent dir).
-
-        Returns:
-            bool: True if the file exists; False if not.
-        """
-        try:
-            os.stat(file_path)
-            return True
-        except OSError as e:
-            if e.errno != errno.ENOENT:
-                raise e
-            return False
+        return path_exists(file_path)
 
     @classmethod
     def check_file(cls, file_path, config_name):
@@ -136,42 +172,106 @@ class Config(object):
         with open(file_path) as file_stream:
             return file_stream.read()
 
-    def invoke_all(self, name, *args, **kargs):
-        """Invoke all instance methods with the given name and arguments in the
-        class's MRO.
+
+class RootConfig(object):
+    """
+    Holder of an application's configuration.
+
+    What configuration this object holds is defined by `config_classes`, a list
+    of Config classes that will be instantiated and given the contents of a
+    configuration file to read. They can then be accessed on this class by their
+    section name, defined in the Config or dynamically set to be the name of the
+    class, lower-cased and with "Config" removed.
+    """
+
+    config_classes = []
+
+    def __init__(self):
+        self._configs = OrderedDict()
+
+        for config_class in self.config_classes:
+            if config_class.section is None:
+                raise ValueError("%r requires a section name" % (config_class,))
+
+            try:
+                conf = config_class(self)
+            except Exception as e:
+                raise Exception("Failed making %s: %r" % (config_class.section, e))
+            self._configs[config_class.section] = conf
+
+    def __getattr__(self, item: str) -> Any:
+        """
+        Redirect lookups on this object either to config objects, or values on
+        config objects, so that `config.tls.blah` works, as well as legacy uses
+        of things like `config.server_name`. It will first look up the config
+        section name, and then values on those config classes.
+        """
+        if item in self._configs.keys():
+            return self._configs[item]
+
+        return self._get_unclassed_config(None, item)
+
+    def _get_unclassed_config(self, asking_section: Optional[str], item: str):
+        """
+        Fetch a config value from one of the instantiated config classes that
+        has not been fetched directly.
+
+        Args:
+            asking_section: If this check is coming from a Config child, which
+                one? This section will not be asked if it has the value.
+            item: The configuration value key.
+
+        Raises:
+            AttributeError if no config classes have the config key. The body
+                will contain what sections were checked.
+        """
+        for key, val in self._configs.items():
+            if key == asking_section:
+                continue
+
+            if item in dir(val):
+                return getattr(val, item)
+
+        raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),))
+
+    def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
+        """
+        Invoke a function on all instantiated config objects this RootConfig is
+        configured to use.
 
         Args:
-            name (str): Name of function to invoke
+            func_name: Name of function to invoke
             *args
             **kwargs
-
         Returns:
-            list: The list of the return values from each method called
+            ordered dictionary of config section name and the result of the
+            function from it.
         """
-        results = []
-        for cls in type(self).mro():
-            if name in cls.__dict__:
-                results.append(getattr(cls, name)(self, *args, **kargs))
-        return results
+        res = OrderedDict()
+
+        for name, config in self._configs.items():
+            if hasattr(config, func_name):
+                res[name] = getattr(config, func_name)(*args, **kwargs)
+
+        return res
 
     @classmethod
-    def invoke_all_static(cls, name, *args, **kargs):
-        """Invoke all static methods with the given name and arguments in the
-        class's MRO.
+    def invoke_all_static(cls, func_name: str, *args, **kwargs):
+        """
+        Invoke a static function on config objects this RootConfig is
+        configured to use.
 
         Args:
-            name (str): Name of function to invoke
+            func_name: Name of function to invoke
             *args
             **kwargs
-
         Returns:
-            list: The list of the return values from each method called
+            ordered dictionary of config section name and the result of the
+            function from it.
         """
-        results = []
-        for c in cls.mro():
-            if name in c.__dict__:
-                results.append(getattr(c, name)(*args, **kargs))
-        return results
+        for config in cls.config_classes:
+            if hasattr(config, func_name):
+                getattr(config, func_name)(*args, **kwargs)
 
     def generate_config(
         self,
@@ -187,7 +287,8 @@ class Config(object):
         tls_private_key_path=None,
         acme_domain=None,
     ):
-        """Build a default configuration file
+        """
+        Build a default configuration file
 
         This is used when the user explicitly asks us to generate a config file
         (eg with --generate_config).
@@ -242,6 +343,7 @@ class Config(object):
         Returns:
             str: the yaml config file
         """
+
         return "\n\n".join(
             dedent(conf)
             for conf in self.invoke_all(
@@ -257,7 +359,7 @@ class Config(object):
                 tls_certificate_path=tls_certificate_path,
                 tls_private_key_path=tls_private_key_path,
                 acme_domain=acme_domain,
-            )
+            ).values()
         )
 
     @classmethod
@@ -444,7 +546,7 @@ class Config(object):
                 )
 
             (config_path,) = config_files
-            if not cls.path_exists(config_path):
+            if not path_exists(config_path):
                 print("Generating config file %s" % (config_path,))
 
                 if config_args.data_directory:
@@ -469,7 +571,7 @@ class Config(object):
                     open_private_ports=config_args.open_private_ports,
                 )
 
-                if not cls.path_exists(config_dir_path):
+                if not path_exists(config_dir_path):
                     os.makedirs(config_dir_path)
                 with open(config_path, "w") as config_file:
                     config_file.write("# vim:ft=yaml\n\n")
@@ -518,7 +620,7 @@ class Config(object):
 
         return obj
 
-    def parse_config_dict(self, config_dict, config_dir_path, data_dir_path):
+    def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None):
         """Read the information from the config dict into this Config object.
 
         Args:
@@ -607,3 +709,6 @@ def find_config_files(search_paths):
             else:
                 config_files.append(config_path)
     return config_files
+
+
+__all__ = ["Config", "RootConfig"]