summary refs log tree commit diff
path: root/synapse/config/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/config/_base.py')
-rw-r--r--synapse/config/_base.py242
1 files changed, 147 insertions, 95 deletions
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index f7d7f153bb..965478d8d5 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -134,11 +136,6 @@ class Config(object):
         with open(file_path) as file_stream:
             return file_stream.read()
 
-    @staticmethod
-    def read_config_file(file_path):
-        with open(file_path) as file_stream:
-            return yaml.safe_load(file_stream)
-
     def invoke_all(self, name, *args, **kargs):
         results = []
         for cls in type(self).mro():
@@ -153,12 +150,12 @@ class Config(object):
         server_name,
         generate_secrets=False,
         report_stats=None,
+        open_private_ports=False,
     ):
         """Build a default configuration file
 
-        This is used both when the user explicitly asks us to generate a config file
-        (eg with --generate_config), and before loading the config at runtime (to give
-        a base which the config files override)
+        This is used when the user explicitly asks us to generate a config file
+        (eg with --generate_config).
 
         Args:
             config_dir_path (str): The path where the config files are kept. Used to
@@ -177,25 +174,33 @@ class Config(object):
             report_stats (bool|None): Initial setting for the report_stats setting.
                 If None, report_stats will be left unset.
 
+            open_private_ports (bool): True to leave private ports (such as the non-TLS
+                HTTP listener) open to the internet.
+
         Returns:
             str: the yaml config file
         """
-        default_config = "\n\n".join(
+        return "\n\n".join(
             dedent(conf)
             for conf in self.invoke_all(
-                "default_config",
+                "generate_config_section",
                 config_dir_path=config_dir_path,
                 data_dir_path=data_dir_path,
                 server_name=server_name,
                 generate_secrets=generate_secrets,
                 report_stats=report_stats,
+                open_private_ports=open_private_ports,
             )
         )
 
-        return default_config
-
     @classmethod
     def load_config(cls, description, argv):
+        """Parse the commandline and config files
+
+        Doesn't support config-file-generation: used by the worker apps.
+
+        Returns: Config object.
+        """
         config_parser = argparse.ArgumentParser(description=description)
         config_parser.add_argument(
             "-c",
@@ -210,7 +215,7 @@ class Config(object):
             "--keys-directory",
             metavar="DIRECTORY",
             help="Where files such as certs and signing keys are stored when"
-            " their location is given explicitly in the config."
+            " their location is not given explicitly in the config."
             " Defaults to the directory containing the last config file",
         )
 
@@ -222,8 +227,19 @@ class Config(object):
 
         config_files = find_config_files(search_paths=config_args.config_path)
 
-        obj.read_config_files(
-            config_files, keys_directory=config_args.keys_directory, generate_keys=False
+        if not config_files:
+            config_parser.error("Must supply a config file.")
+
+        if config_args.keys_directory:
+            config_dir_path = config_args.keys_directory
+        else:
+            config_dir_path = os.path.dirname(config_files[-1])
+        config_dir_path = os.path.abspath(config_dir_path)
+        data_dir_path = os.getcwd()
+
+        config_dict = read_config_files(config_files)
+        obj.parse_config_dict(
+            config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path
         )
 
         obj.invoke_all("read_arguments", config_args)
@@ -232,6 +248,12 @@ class Config(object):
 
     @classmethod
     def load_or_generate_config(cls, description, argv):
+        """Parse the commandline and config files
+
+        Supports generation of config files, so is used for the main homeserver app.
+
+        Returns: Config object, or None if --generate-config or --generate-keys was set
+        """
         config_parser = argparse.ArgumentParser(add_help=False)
         config_parser.add_argument(
             "-c",
@@ -241,37 +263,74 @@ class Config(object):
             help="Specify config file. Can be given multiple times and"
             " may specify directories containing *.yaml files.",
         )
-        config_parser.add_argument(
+
+        generate_group = config_parser.add_argument_group("Config generation")
+        generate_group.add_argument(
             "--generate-config",
             action="store_true",
-            help="Generate a config file for the server name",
+            help="Generate a config file, then exit.",
         )
-        config_parser.add_argument(
+        generate_group.add_argument(
+            "--generate-missing-configs",
+            "--generate-keys",
+            action="store_true",
+            help="Generate any missing additional config files, then exit.",
+        )
+        generate_group.add_argument(
+            "-H", "--server-name", help="The server name to generate a config file for."
+        )
+        generate_group.add_argument(
             "--report-stats",
             action="store",
-            help="Whether the generated config reports anonymized usage statistics",
+            help="Whether the generated config reports anonymized usage statistics.",
             choices=["yes", "no"],
         )
-        config_parser.add_argument(
-            "--generate-keys",
-            action="store_true",
-            help="Generate any missing key files then exit",
-        )
-        config_parser.add_argument(
+        generate_group.add_argument(
+            "--config-directory",
             "--keys-directory",
             metavar="DIRECTORY",
-            help="Used with 'generate-*' options to specify where files such as"
-            " signing keys should be stored, unless explicitly"
-            " specified in the config.",
+            help=(
+                "Specify where additional config files such as signing keys and log"
+                " config should be stored. Defaults to the same directory as the last"
+                " config file."
+            ),
         )
-        config_parser.add_argument(
-            "-H", "--server-name", help="The server name to generate a config file for"
+        generate_group.add_argument(
+            "--data-directory",
+            metavar="DIRECTORY",
+            help=(
+                "Specify where data such as the media store and database file should be"
+                " stored. Defaults to the current working directory."
+            ),
+        )
+        generate_group.add_argument(
+            "--open-private-ports",
+            action="store_true",
+            help=(
+                "Leave private ports (such as the non-TLS HTTP listener) open to the"
+                " internet. Do not use this unless you know what you are doing."
+            ),
         )
+
         config_args, remaining_args = config_parser.parse_known_args(argv)
 
         config_files = find_config_files(search_paths=config_args.config_path)
 
-        generate_keys = config_args.generate_keys
+        if not config_files:
+            config_parser.error(
+                "Must supply a config file.\nA config file can be automatically"
+                ' generated using "--generate-config -H SERVER_NAME'
+                ' -c CONFIG-FILE"'
+            )
+
+        if config_args.config_directory:
+            config_dir_path = config_args.config_directory
+        else:
+            config_dir_path = os.path.dirname(config_files[-1])
+        config_dir_path = os.path.abspath(config_dir_path)
+        data_dir_path = os.getcwd()
+
+        generate_missing_configs = config_args.generate_missing_configs
 
         obj = cls()
 
@@ -281,19 +340,16 @@ class Config(object):
                     "Please specify either --report-stats=yes or --report-stats=no\n\n"
                     + MISSING_REPORT_STATS_SPIEL
                 )
-            if not config_files:
-                config_parser.error(
-                    "Must supply a config file.\nA config file can be automatically"
-                    " generated using \"--generate-config -H SERVER_NAME"
-                    " -c CONFIG-FILE\""
-                )
+
             (config_path,) = config_files
             if not cls.path_exists(config_path):
-                if config_args.keys_directory:
-                    config_dir_path = config_args.keys_directory
+                print("Generating config file %s" % (config_path,))
+
+                if config_args.data_directory:
+                    data_dir_path = config_args.data_directory
                 else:
-                    config_dir_path = os.path.dirname(config_path)
-                config_dir_path = os.path.abspath(config_dir_path)
+                    data_dir_path = os.getcwd()
+                data_dir_path = os.path.abspath(data_dir_path)
 
                 server_name = config_args.server_name
                 if not server_name:
@@ -304,22 +360,21 @@ class Config(object):
 
                 config_str = obj.generate_config(
                     config_dir_path=config_dir_path,
-                    data_dir_path=os.getcwd(),
+                    data_dir_path=data_dir_path,
                     server_name=server_name,
                     report_stats=(config_args.report_stats == "yes"),
                     generate_secrets=True,
+                    open_private_ports=config_args.open_private_ports,
                 )
 
                 if not cls.path_exists(config_dir_path):
                     os.makedirs(config_dir_path)
                 with open(config_path, "w") as config_file:
-                    config_file.write(
-                        "# vim:ft=yaml\n\n"
-                    )
+                    config_file.write("# vim:ft=yaml\n\n")
                     config_file.write(config_str)
 
-                config = yaml.safe_load(config_str)
-                obj.invoke_all("generate_files", config)
+                config_dict = yaml.safe_load(config_str)
+                obj.generate_missing_files(config_dict, config_dir_path)
 
                 print(
                     (
@@ -333,12 +388,12 @@ class Config(object):
             else:
                 print(
                     (
-                        "Config file %r already exists. Generating any missing key"
+                        "Config file %r already exists. Generating any missing config"
                         " files."
                     )
                     % (config_path,)
                 )
-                generate_keys = True
+                generate_missing_configs = True
 
         parser = argparse.ArgumentParser(
             parents=[config_parser],
@@ -349,66 +404,63 @@ class Config(object):
         obj.invoke_all("add_arguments", parser)
         args = parser.parse_args(remaining_args)
 
-        if not config_files:
-            config_parser.error(
-                "Must supply a config file.\nA config file can be automatically"
-                " generated using \"--generate-config -H SERVER_NAME"
-                " -c CONFIG-FILE\""
-            )
-
-        obj.read_config_files(
-            config_files,
-            keys_directory=config_args.keys_directory,
-            generate_keys=generate_keys,
-        )
-
-        if generate_keys:
+        config_dict = read_config_files(config_files)
+        if generate_missing_configs:
+            obj.generate_missing_files(config_dict, config_dir_path)
             return None
 
+        obj.parse_config_dict(
+            config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path
+        )
         obj.invoke_all("read_arguments", args)
 
         return obj
 
-    def read_config_files(self, config_files, keys_directory=None, generate_keys=False):
-        if not keys_directory:
-            keys_directory = os.path.dirname(config_files[-1])
+    def parse_config_dict(self, config_dict, config_dir_path, data_dir_path):
+        """Read the information from the config dict into this Config object.
 
-        self.config_dir_path = os.path.abspath(keys_directory)
-
-        specified_config = {}
-        for config_file in config_files:
-            yaml_config = self.read_config_file(config_file)
-            specified_config.update(yaml_config)
+        Args:
+            config_dict (dict): Configuration data, as read from the yaml
 
-        if "server_name" not in specified_config:
-            raise ConfigError(MISSING_SERVER_NAME)
+            config_dir_path (str): The path where the config files are kept. Used to
+                create filenames for things like the log config and the signing key.
 
-        server_name = specified_config["server_name"]
-        config_string = self.generate_config(
-            config_dir_path=self.config_dir_path,
-            data_dir_path=os.getcwd(),
-            server_name=server_name,
-            generate_secrets=False,
+            data_dir_path (str): The path where the data files are kept. Used to create
+                filenames for things like the database and media store.
+        """
+        self.invoke_all(
+            "read_config",
+            config_dict,
+            config_dir_path=config_dir_path,
+            data_dir_path=data_dir_path,
         )
-        config = yaml.safe_load(config_string)
-        config.pop("log_config")
-        config.update(specified_config)
 
-        if "report_stats" not in config:
-            raise ConfigError(
-                MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS
-                + "\n"
-                + MISSING_REPORT_STATS_SPIEL
-            )
+    def generate_missing_files(self, config_dict, config_dir_path):
+        self.invoke_all("generate_files", config_dict, config_dir_path)
 
-        if generate_keys:
-            self.invoke_all("generate_files", config)
-            return
 
-        self.parse_config_dict(config)
+def read_config_files(config_files):
+    """Read the config files into a dict
 
-    def parse_config_dict(self, config_dict):
-        self.invoke_all("read_config", config_dict)
+    Args:
+        config_files (iterable[str]): A list of the config files to read
+
+    Returns: dict
+    """
+    specified_config = {}
+    for config_file in config_files:
+        with open(config_file) as file_stream:
+            yaml_config = yaml.safe_load(file_stream)
+        specified_config.update(yaml_config)
+
+    if "server_name" not in specified_config:
+        raise ConfigError(MISSING_SERVER_NAME)
+
+    if "report_stats" not in specified_config:
+        raise ConfigError(
+            MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + MISSING_REPORT_STATS_SPIEL
+        )
+    return specified_config
 
 
 def find_config_files(search_paths):