diff options
Diffstat (limited to 'synapse/config/_base.py')
-rw-r--r-- | synapse/config/_base.py | 239 |
1 files changed, 176 insertions, 63 deletions
diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 31f6530978..30d1050a91 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -18,7 +18,9 @@ import argparse import errno import os +from collections import OrderedDict from textwrap import dedent +from typing import Any, MutableMapping, Optional from six import integer_types @@ -51,7 +53,68 @@ Missing mandatory `server_name` config option. """ +CONFIG_FILE_HEADER = """\ +# Configuration file for Synapse. +# +# This is a YAML file: see [1] for a quick introduction. Note in particular +# that *indentation is important*: all the elements of a list or dictionary +# should have the same indentation. +# +# [1] https://docs.ansible.com/ansible/latest/reference_appendices/YAMLSyntax.html + +""" + + +def path_exists(file_path): + """Check if a file exists + + Unlike os.path.exists, this throws an exception if there is an error + checking if the file exists (for example, if there is a perms error on + the parent dir). + + Returns: + bool: True if the file exists; False if not. + """ + try: + os.stat(file_path) + return True + except OSError as e: + if e.errno != errno.ENOENT: + raise e + return False + + class Config(object): + """ + A configuration section, containing configuration keys and values. + + Attributes: + section (str): The section title of this config object, such as + "tls" or "logger". This is used to refer to it on the root + logger (for example, `config.tls.some_option`). Must be + defined in subclasses. + """ + + section = None + + def __init__(self, root_config=None): + self.root = root_config + + def __getattr__(self, item: str) -> Any: + """ + Try and fetch a configuration option that does not exist on this class. + + This is so that existing configs that rely on `self.value`, where value + is actually from a different config section, continue to work. + """ + if item in ["generate_config_section", "read_config"]: + raise AttributeError(item) + + if self.root is None: + raise AttributeError(item) + else: + return self.root._get_unclassed_config(self.section, item) + @staticmethod def parse_size(value): if isinstance(value, integer_types): @@ -88,22 +151,7 @@ class Config(object): @classmethod def path_exists(cls, file_path): - """Check if a file exists - - Unlike os.path.exists, this throws an exception if there is an error - checking if the file exists (for example, if there is a perms error on - the parent dir). - - Returns: - bool: True if the file exists; False if not. - """ - try: - os.stat(file_path) - return True - except OSError as e: - if e.errno != errno.ENOENT: - raise e - return False + return path_exists(file_path) @classmethod def check_file(cls, file_path, config_name): @@ -136,42 +184,106 @@ class Config(object): with open(file_path) as file_stream: return file_stream.read() - def invoke_all(self, name, *args, **kargs): - """Invoke all instance methods with the given name and arguments in the - class's MRO. + +class RootConfig(object): + """ + Holder of an application's configuration. + + What configuration this object holds is defined by `config_classes`, a list + of Config classes that will be instantiated and given the contents of a + configuration file to read. They can then be accessed on this class by their + section name, defined in the Config or dynamically set to be the name of the + class, lower-cased and with "Config" removed. + """ + + config_classes = [] + + def __init__(self): + self._configs = OrderedDict() + + for config_class in self.config_classes: + if config_class.section is None: + raise ValueError("%r requires a section name" % (config_class,)) + + try: + conf = config_class(self) + except Exception as e: + raise Exception("Failed making %s: %r" % (config_class.section, e)) + self._configs[config_class.section] = conf + + def __getattr__(self, item: str) -> Any: + """ + Redirect lookups on this object either to config objects, or values on + config objects, so that `config.tls.blah` works, as well as legacy uses + of things like `config.server_name`. It will first look up the config + section name, and then values on those config classes. + """ + if item in self._configs.keys(): + return self._configs[item] + + return self._get_unclassed_config(None, item) + + def _get_unclassed_config(self, asking_section: Optional[str], item: str): + """ + Fetch a config value from one of the instantiated config classes that + has not been fetched directly. Args: - name (str): Name of function to invoke + asking_section: If this check is coming from a Config child, which + one? This section will not be asked if it has the value. + item: The configuration value key. + + Raises: + AttributeError if no config classes have the config key. The body + will contain what sections were checked. + """ + for key, val in self._configs.items(): + if key == asking_section: + continue + + if item in dir(val): + return getattr(val, item) + + raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),)) + + def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]: + """ + Invoke a function on all instantiated config objects this RootConfig is + configured to use. + + Args: + func_name: Name of function to invoke *args **kwargs - Returns: - list: The list of the return values from each method called + ordered dictionary of config section name and the result of the + function from it. """ - results = [] - for cls in type(self).mro(): - if name in cls.__dict__: - results.append(getattr(cls, name)(self, *args, **kargs)) - return results + res = OrderedDict() + + for name, config in self._configs.items(): + if hasattr(config, func_name): + res[name] = getattr(config, func_name)(*args, **kwargs) + + return res @classmethod - def invoke_all_static(cls, name, *args, **kargs): - """Invoke all static methods with the given name and arguments in the - class's MRO. + def invoke_all_static(cls, func_name: str, *args, **kwargs): + """ + Invoke a static function on config objects this RootConfig is + configured to use. Args: - name (str): Name of function to invoke + func_name: Name of function to invoke *args **kwargs - Returns: - list: The list of the return values from each method called + ordered dictionary of config section name and the result of the + function from it. """ - results = [] - for c in cls.mro(): - if name in c.__dict__: - results.append(getattr(c, name)(*args, **kargs)) - return results + for config in cls.config_classes: + if hasattr(config, func_name): + getattr(config, func_name)(*args, **kwargs) def generate_config( self, @@ -182,12 +294,12 @@ class Config(object): report_stats=None, open_private_ports=False, listeners=None, - database_conf=None, tls_certificate_path=None, tls_private_key_path=None, acme_domain=None, ): - """Build a default configuration file + """ + Build a default configuration file This is used when the user explicitly asks us to generate a config file (eg with --generate_config). @@ -242,7 +354,8 @@ class Config(object): Returns: str: the yaml config file """ - return "\n\n".join( + + return CONFIG_FILE_HEADER + "\n\n".join( dedent(conf) for conf in self.invoke_all( "generate_config_section", @@ -253,11 +366,10 @@ class Config(object): report_stats=report_stats, open_private_ports=open_private_ports, listeners=listeners, - database_conf=database_conf, tls_certificate_path=tls_certificate_path, tls_private_key_path=tls_private_key_path, acme_domain=acme_domain, - ) + ).values() ) @classmethod @@ -356,8 +468,8 @@ class Config(object): Returns: Config object, or None if --generate-config or --generate-keys was set """ - config_parser = argparse.ArgumentParser(add_help=False) - config_parser.add_argument( + parser = argparse.ArgumentParser(description=description) + parser.add_argument( "-c", "--config-path", action="append", @@ -366,7 +478,7 @@ class Config(object): " may specify directories containing *.yaml files.", ) - generate_group = config_parser.add_argument_group("Config generation") + generate_group = parser.add_argument_group("Config generation") generate_group.add_argument( "--generate-config", action="store_true", @@ -414,12 +526,13 @@ class Config(object): ), ) - config_args, remaining_args = config_parser.parse_known_args(argv) + cls.invoke_all_static("add_arguments", parser) + config_args = parser.parse_args(argv) config_files = find_config_files(search_paths=config_args.config_path) if not config_files: - config_parser.error( + parser.error( "Must supply a config file.\nA config file can be automatically" ' generated using "--generate-config -H SERVER_NAME' ' -c CONFIG-FILE"' @@ -438,13 +551,13 @@ class Config(object): if config_args.generate_config: if config_args.report_stats is None: - config_parser.error( + parser.error( "Please specify either --report-stats=yes or --report-stats=no\n\n" + MISSING_REPORT_STATS_SPIEL ) (config_path,) = config_files - if not cls.path_exists(config_path): + if not path_exists(config_path): print("Generating config file %s" % (config_path,)) if config_args.data_directory: @@ -469,11 +582,11 @@ class Config(object): open_private_ports=config_args.open_private_ports, ) - if not cls.path_exists(config_dir_path): + if not path_exists(config_dir_path): os.makedirs(config_dir_path) with open(config_path, "w") as config_file: - config_file.write("# vim:ft=yaml\n\n") config_file.write(config_str) + config_file.write("\n\n# vim:ft=yaml") config_dict = yaml.safe_load(config_str) obj.generate_missing_files(config_dict, config_dir_path) @@ -497,15 +610,6 @@ class Config(object): ) generate_missing_configs = True - parser = argparse.ArgumentParser( - parents=[config_parser], - description=description, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - - obj.invoke_all_static("add_arguments", parser) - args = parser.parse_args(remaining_args) - config_dict = read_config_files(config_files) if generate_missing_configs: obj.generate_missing_files(config_dict, config_dir_path) @@ -514,11 +618,11 @@ class Config(object): obj.parse_config_dict( config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path ) - obj.invoke_all("read_arguments", args) + obj.invoke_all("read_arguments", config_args) return obj - def parse_config_dict(self, config_dict, config_dir_path, data_dir_path): + def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None): """Read the information from the config dict into this Config object. Args: @@ -553,6 +657,12 @@ def read_config_files(config_files): for config_file in config_files: with open(config_file) as file_stream: yaml_config = yaml.safe_load(file_stream) + + if not isinstance(yaml_config, dict): + err = "File %r is empty or doesn't parse into a key-value map. IGNORING." + print(err % (config_file,)) + continue + specified_config.update(yaml_config) if "server_name" not in specified_config: @@ -607,3 +717,6 @@ def find_config_files(search_paths): else: config_files.append(config_path) return config_files + + +__all__ = ["Config", "RootConfig"] |