summary refs log tree commit diff
path: root/synapse/config/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/config/_base.py')
-rw-r--r--synapse/config/_base.py106
1 files changed, 35 insertions, 71 deletions
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 21d110c82d..8757416a60 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -136,11 +136,6 @@ class Config(object):
         with open(file_path) as file_stream:
             return file_stream.read()
 
-    @staticmethod
-    def read_config_file(file_path):
-        with open(file_path) as file_stream:
-            return yaml.safe_load(file_stream)
-
     def invoke_all(self, name, *args, **kargs):
         results = []
         for cls in type(self).mro():
@@ -158,9 +153,8 @@ class Config(object):
     ):
         """Build a default configuration file
 
-        This is used both when the user explicitly asks us to generate a config file
-        (eg with --generate_config), and before loading the config at runtime (to give
-        a base which the config files override)
+        This is used when the user explicitly asks us to generate a config file
+        (eg with --generate_config).
 
         Args:
             config_dir_path (str): The path where the config files are kept. Used to
@@ -182,10 +176,10 @@ class Config(object):
         Returns:
             str: the yaml config file
         """
-        default_config = "\n\n".join(
+        return "\n\n".join(
             dedent(conf)
             for conf in self.invoke_all(
-                "default_config",
+                "generate_config_section",
                 config_dir_path=config_dir_path,
                 data_dir_path=data_dir_path,
                 server_name=server_name,
@@ -194,8 +188,6 @@ class Config(object):
             )
         )
 
-        return default_config
-
     @classmethod
     def load_config(cls, description, argv):
         """Parse the commandline and config files
@@ -240,9 +232,7 @@ class Config(object):
         config_dir_path = os.path.abspath(config_dir_path)
         data_dir_path = os.getcwd()
 
-        config_dict = obj.read_config_files(
-            config_files, config_dir_path=config_dir_path, data_dir_path=data_dir_path
-        )
+        config_dict = read_config_files(config_files)
         obj.parse_config_dict(
             config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path
         )
@@ -354,8 +344,8 @@ class Config(object):
                     config_file.write("# vim:ft=yaml\n\n")
                     config_file.write(config_str)
 
-                config = yaml.safe_load(config_str)
-                obj.invoke_all("generate_files", config)
+                config_dict = yaml.safe_load(config_str)
+                obj.generate_missing_files(config_dict, config_dir_path)
 
                 print(
                     (
@@ -385,12 +375,9 @@ class Config(object):
         obj.invoke_all("add_arguments", parser)
         args = parser.parse_args(remaining_args)
 
-        config_dict = obj.read_config_files(
-            config_files, config_dir_path=config_dir_path, data_dir_path=data_dir_path
-        )
-
+        config_dict = read_config_files(config_files)
         if generate_missing_configs:
-            obj.generate_missing_files(config_dict)
+            obj.generate_missing_files(config_dict, config_dir_path)
             return None
 
         obj.parse_config_dict(
@@ -400,53 +387,6 @@ class Config(object):
 
         return obj
 
-    def read_config_files(self, config_files, config_dir_path, data_dir_path):
-        """Read the config files into a dict
-
-        Args:
-            config_files (iterable[str]): A list of the config files to read
-
-            config_dir_path (str): The path where the config files are kept. Used to
-                create filenames for things like the log config and the signing key.
-
-            data_dir_path (str): The path where the data files are kept. Used to create
-                filenames for things like the database and media store.
-
-        Returns: dict
-        """
-        # first we read the config files into a dict
-        specified_config = {}
-        for config_file in config_files:
-            yaml_config = self.read_config_file(config_file)
-            specified_config.update(yaml_config)
-
-        # not all of the options have sensible defaults in code, so we now need to
-        # generate a default config file suitable for the specified server name...
-        if "server_name" not in specified_config:
-            raise ConfigError(MISSING_SERVER_NAME)
-        server_name = specified_config["server_name"]
-        config_string = self.generate_config(
-            config_dir_path=config_dir_path,
-            data_dir_path=data_dir_path,
-            server_name=server_name,
-            generate_secrets=False,
-        )
-
-        # ... and read it into a base config dict ...
-        config = yaml.safe_load(config_string)
-
-        # ... and finally, overlay it with the actual configuration.
-        config.pop("log_config")
-        config.update(specified_config)
-
-        if "report_stats" not in config:
-            raise ConfigError(
-                MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS
-                + "\n"
-                + MISSING_REPORT_STATS_SPIEL
-            )
-        return config
-
     def parse_config_dict(self, config_dict, config_dir_path, data_dir_path):
         """Read the information from the config dict into this Config object.
 
@@ -466,8 +406,32 @@ class Config(object):
             data_dir_path=data_dir_path,
         )
 
-    def generate_missing_files(self, config_dict):
-        self.invoke_all("generate_files", config_dict)
+    def generate_missing_files(self, config_dict, config_dir_path):
+        self.invoke_all("generate_files", config_dict, config_dir_path)
+
+
+def read_config_files(config_files):
+    """Read the config files into a dict
+
+    Args:
+        config_files (iterable[str]): A list of the config files to read
+
+    Returns: dict
+    """
+    specified_config = {}
+    for config_file in config_files:
+        with open(config_file) as file_stream:
+            yaml_config = yaml.safe_load(file_stream)
+        specified_config.update(yaml_config)
+
+    if "server_name" not in specified_config:
+        raise ConfigError(MISSING_SERVER_NAME)
+
+    if "report_stats" not in specified_config:
+        raise ConfigError(
+            MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + MISSING_REPORT_STATS_SPIEL
+        )
+    return specified_config
 
 
 def find_config_files(search_paths):