diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 31f6530978..30d1050a91 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,7 +18,9 @@
import argparse
import errno
import os
+from collections import OrderedDict
from textwrap import dedent
+from typing import Any, MutableMapping, Optional
from six import integer_types
@@ -51,7 +53,68 @@ Missing mandatory `server_name` config option.
"""
+CONFIG_FILE_HEADER = """\
+# Configuration file for Synapse.
+#
+# This is a YAML file: see [1] for a quick introduction. Note in particular
+# that *indentation is important*: all the elements of a list or dictionary
+# should have the same indentation.
+#
+# [1] https://docs.ansible.com/ansible/latest/reference_appendices/YAMLSyntax.html
+
+"""
+
+
+def path_exists(file_path):
+ """Check if a file exists
+
+ Unlike os.path.exists, this throws an exception if there is an error
+ checking if the file exists (for example, if there is a perms error on
+ the parent dir).
+
+ Returns:
+ bool: True if the file exists; False if not.
+ """
+ try:
+ os.stat(file_path)
+ return True
+ except OSError as e:
+ if e.errno != errno.ENOENT:
+ raise e
+ return False
+
+
class Config(object):
+ """
+ A configuration section, containing configuration keys and values.
+
+ Attributes:
+ section (str): The section title of this config object, such as
+ "tls" or "logger". This is used to refer to it on the root
+ logger (for example, `config.tls.some_option`). Must be
+ defined in subclasses.
+ """
+
+ section = None
+
+ def __init__(self, root_config=None):
+ self.root = root_config
+
+ def __getattr__(self, item: str) -> Any:
+ """
+ Try and fetch a configuration option that does not exist on this class.
+
+ This is so that existing configs that rely on `self.value`, where value
+ is actually from a different config section, continue to work.
+ """
+ if item in ["generate_config_section", "read_config"]:
+ raise AttributeError(item)
+
+ if self.root is None:
+ raise AttributeError(item)
+ else:
+ return self.root._get_unclassed_config(self.section, item)
+
@staticmethod
def parse_size(value):
if isinstance(value, integer_types):
@@ -88,22 +151,7 @@ class Config(object):
@classmethod
def path_exists(cls, file_path):
- """Check if a file exists
-
- Unlike os.path.exists, this throws an exception if there is an error
- checking if the file exists (for example, if there is a perms error on
- the parent dir).
-
- Returns:
- bool: True if the file exists; False if not.
- """
- try:
- os.stat(file_path)
- return True
- except OSError as e:
- if e.errno != errno.ENOENT:
- raise e
- return False
+ return path_exists(file_path)
@classmethod
def check_file(cls, file_path, config_name):
@@ -136,42 +184,106 @@ class Config(object):
with open(file_path) as file_stream:
return file_stream.read()
- def invoke_all(self, name, *args, **kargs):
- """Invoke all instance methods with the given name and arguments in the
- class's MRO.
+
+class RootConfig(object):
+ """
+ Holder of an application's configuration.
+
+ What configuration this object holds is defined by `config_classes`, a list
+ of Config classes that will be instantiated and given the contents of a
+ configuration file to read. They can then be accessed on this class by their
+ section name, defined in the Config or dynamically set to be the name of the
+ class, lower-cased and with "Config" removed.
+ """
+
+ config_classes = []
+
+ def __init__(self):
+ self._configs = OrderedDict()
+
+ for config_class in self.config_classes:
+ if config_class.section is None:
+ raise ValueError("%r requires a section name" % (config_class,))
+
+ try:
+ conf = config_class(self)
+ except Exception as e:
+ raise Exception("Failed making %s: %r" % (config_class.section, e))
+ self._configs[config_class.section] = conf
+
+ def __getattr__(self, item: str) -> Any:
+ """
+ Redirect lookups on this object either to config objects, or values on
+ config objects, so that `config.tls.blah` works, as well as legacy uses
+ of things like `config.server_name`. It will first look up the config
+ section name, and then values on those config classes.
+ """
+ if item in self._configs.keys():
+ return self._configs[item]
+
+ return self._get_unclassed_config(None, item)
+
+ def _get_unclassed_config(self, asking_section: Optional[str], item: str):
+ """
+ Fetch a config value from one of the instantiated config classes that
+ has not been fetched directly.
Args:
- name (str): Name of function to invoke
+ asking_section: If this check is coming from a Config child, which
+ one? This section will not be asked if it has the value.
+ item: The configuration value key.
+
+ Raises:
+ AttributeError if no config classes have the config key. The body
+ will contain what sections were checked.
+ """
+ for key, val in self._configs.items():
+ if key == asking_section:
+ continue
+
+ if item in dir(val):
+ return getattr(val, item)
+
+ raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),))
+
+ def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
+ """
+ Invoke a function on all instantiated config objects this RootConfig is
+ configured to use.
+
+ Args:
+ func_name: Name of function to invoke
*args
**kwargs
-
Returns:
- list: The list of the return values from each method called
+ ordered dictionary of config section name and the result of the
+ function from it.
"""
- results = []
- for cls in type(self).mro():
- if name in cls.__dict__:
- results.append(getattr(cls, name)(self, *args, **kargs))
- return results
+ res = OrderedDict()
+
+ for name, config in self._configs.items():
+ if hasattr(config, func_name):
+ res[name] = getattr(config, func_name)(*args, **kwargs)
+
+ return res
@classmethod
- def invoke_all_static(cls, name, *args, **kargs):
- """Invoke all static methods with the given name and arguments in the
- class's MRO.
+ def invoke_all_static(cls, func_name: str, *args, **kwargs):
+ """
+ Invoke a static function on config objects this RootConfig is
+ configured to use.
Args:
- name (str): Name of function to invoke
+ func_name: Name of function to invoke
*args
**kwargs
-
Returns:
- list: The list of the return values from each method called
+ ordered dictionary of config section name and the result of the
+ function from it.
"""
- results = []
- for c in cls.mro():
- if name in c.__dict__:
- results.append(getattr(c, name)(*args, **kargs))
- return results
+ for config in cls.config_classes:
+ if hasattr(config, func_name):
+ getattr(config, func_name)(*args, **kwargs)
def generate_config(
self,
@@ -182,12 +294,12 @@ class Config(object):
report_stats=None,
open_private_ports=False,
listeners=None,
- database_conf=None,
tls_certificate_path=None,
tls_private_key_path=None,
acme_domain=None,
):
- """Build a default configuration file
+ """
+ Build a default configuration file
This is used when the user explicitly asks us to generate a config file
(eg with --generate_config).
@@ -242,7 +354,8 @@ class Config(object):
Returns:
str: the yaml config file
"""
- return "\n\n".join(
+
+ return CONFIG_FILE_HEADER + "\n\n".join(
dedent(conf)
for conf in self.invoke_all(
"generate_config_section",
@@ -253,11 +366,10 @@ class Config(object):
report_stats=report_stats,
open_private_ports=open_private_ports,
listeners=listeners,
- database_conf=database_conf,
tls_certificate_path=tls_certificate_path,
tls_private_key_path=tls_private_key_path,
acme_domain=acme_domain,
- )
+ ).values()
)
@classmethod
@@ -356,8 +468,8 @@ class Config(object):
Returns: Config object, or None if --generate-config or --generate-keys was set
"""
- config_parser = argparse.ArgumentParser(add_help=False)
- config_parser.add_argument(
+ parser = argparse.ArgumentParser(description=description)
+ parser.add_argument(
"-c",
"--config-path",
action="append",
@@ -366,7 +478,7 @@ class Config(object):
" may specify directories containing *.yaml files.",
)
- generate_group = config_parser.add_argument_group("Config generation")
+ generate_group = parser.add_argument_group("Config generation")
generate_group.add_argument(
"--generate-config",
action="store_true",
@@ -414,12 +526,13 @@ class Config(object):
),
)
- config_args, remaining_args = config_parser.parse_known_args(argv)
+ cls.invoke_all_static("add_arguments", parser)
+ config_args = parser.parse_args(argv)
config_files = find_config_files(search_paths=config_args.config_path)
if not config_files:
- config_parser.error(
+ parser.error(
"Must supply a config file.\nA config file can be automatically"
' generated using "--generate-config -H SERVER_NAME'
' -c CONFIG-FILE"'
@@ -438,13 +551,13 @@ class Config(object):
if config_args.generate_config:
if config_args.report_stats is None:
- config_parser.error(
+ parser.error(
"Please specify either --report-stats=yes or --report-stats=no\n\n"
+ MISSING_REPORT_STATS_SPIEL
)
(config_path,) = config_files
- if not cls.path_exists(config_path):
+ if not path_exists(config_path):
print("Generating config file %s" % (config_path,))
if config_args.data_directory:
@@ -469,11 +582,11 @@ class Config(object):
open_private_ports=config_args.open_private_ports,
)
- if not cls.path_exists(config_dir_path):
+ if not path_exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "w") as config_file:
- config_file.write("# vim:ft=yaml\n\n")
config_file.write(config_str)
+ config_file.write("\n\n# vim:ft=yaml")
config_dict = yaml.safe_load(config_str)
obj.generate_missing_files(config_dict, config_dir_path)
@@ -497,15 +610,6 @@ class Config(object):
)
generate_missing_configs = True
- parser = argparse.ArgumentParser(
- parents=[config_parser],
- description=description,
- formatter_class=argparse.RawDescriptionHelpFormatter,
- )
-
- obj.invoke_all_static("add_arguments", parser)
- args = parser.parse_args(remaining_args)
-
config_dict = read_config_files(config_files)
if generate_missing_configs:
obj.generate_missing_files(config_dict, config_dir_path)
@@ -514,11 +618,11 @@ class Config(object):
obj.parse_config_dict(
config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path
)
- obj.invoke_all("read_arguments", args)
+ obj.invoke_all("read_arguments", config_args)
return obj
- def parse_config_dict(self, config_dict, config_dir_path, data_dir_path):
+ def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None):
"""Read the information from the config dict into this Config object.
Args:
@@ -553,6 +657,12 @@ def read_config_files(config_files):
for config_file in config_files:
with open(config_file) as file_stream:
yaml_config = yaml.safe_load(file_stream)
+
+ if not isinstance(yaml_config, dict):
+ err = "File %r is empty or doesn't parse into a key-value map. IGNORING."
+ print(err % (config_file,))
+ continue
+
specified_config.update(yaml_config)
if "server_name" not in specified_config:
@@ -607,3 +717,6 @@ def find_config_files(search_paths):
else:
config_files.append(config_path)
return config_files
+
+
+__all__ = ["Config", "RootConfig"]
|