summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/_base.py58
-rw-r--r--synapse/config/homeserver.py2
2 files changed, 38 insertions, 22 deletions
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 9f5da70948..d98b6aaedf 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -117,51 +117,59 @@ class Config(object):
 
         config = yaml.load(default_config)
 
-        if not os.path.exists(config_dir_path):
-            os.makedirs(config_dir_path)
-
-        self.invoke_all("generate_keys", config)
-
-        return default_config
+        return default_config, config
 
     @classmethod
     def load_config(cls, description, argv, generate_section=None):
-        result = cls()
+        obj = cls()
 
         config_parser = argparse.ArgumentParser(add_help=False)
         config_parser.add_argument(
             "-c", "--config-path",
+            action="append",
             metavar="CONFIG_FILE",
             help="Specify config file"
         )
         config_parser.add_argument(
             "--generate-config",
-            metavar="SERVER_NAME",
+            action="store_true",
             help="Generate a config file for the server name"
         )
+        config_parser.add_argument(
+            "-H", "--server-name",
+            help="The server name to generate a config file for"
+        )
         config_args, remaining_args = config_parser.parse_known_args(argv)
 
         if not config_args.config_path:
             config_parser.error(
                 "Must supply a config file.\nA config file can be automatically"
-                " generated using \"--generate-config SERVER_NAME"
+                " generated using \"--generate-config -h SERVER_NAME"
                 " -c CONFIG-FILE\""
             )
 
+        config_dir_path = os.path.dirname(config_args.config_path[0])
+        config_dir_path = os.path.abspath(config_dir_path)
         if config_args.generate_config:
-            server_name = config_args.generate_config
-            config_path = config_args.config_path
+            server_name = config_args.server_name
+            if not server_name:
+                print "Most specify a server_name to a generate config for."
+                sys.exit(1)
+            (config_path,) = config_args.config_path
             if os.path.exists(config_path):
                 print "Config file %r already exists. Not overwriting" % (
                     config_args.config_path
                 )
-                sys.exit(0)
-            config_dir_path = os.path.dirname(config_args.config_path)
-            config_dir_path = os.path.abspath(config_dir_path)
+                sys.exit(1)
+            if not os.path.exists(config_dir_path):
+                os.makedirs(config_dir_path)
             with open(config_path, "wb") as config_file:
-                config_file.write(
-                    result.generate_config(config_dir_path, server_name)
+
+                config_bytes, config = obj.generate_config(
+                    config_dir_path, server_name
                 )
+                obj.invoke_all("generate_keys", config)
+                config_file.write(config_bytes)
             print (
                 "A config file has been generated in %s for server name"
                 " '%s' with corresponding SSL keys and self-signed"
@@ -174,8 +182,16 @@ class Config(object):
             )
             sys.exit(0)
 
-        config = cls.read_config_file(config_args.config_path)
-        result.invoke_all("read_config", config)
+        specified_config = {}
+        for config_path in config_args.config_path:
+            yaml_config = cls.read_config_file(config_path)
+            specified_config.update(yaml_config)
+
+        server_name = specified_config["server_name"]
+        _, config = obj.generate_config(config_dir_path, server_name)
+        config.update(specified_config)
+
+        obj.invoke_all("read_config", config)
 
         parser = argparse.ArgumentParser(
             parents=[config_parser],
@@ -183,9 +199,9 @@ class Config(object):
             formatter_class=argparse.RawDescriptionHelpFormatter,
         )
 
-        result.invoke_all("add_arguments", parser)
+        obj.invoke_all("add_arguments", parser)
         args = parser.parse_args(remaining_args)
 
-        result.invoke_all("read_arguments", args)
+        obj.invoke_all("read_arguments", args)
 
-        return result
+        return obj
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index f9b4807a35..fe0ccb6eb7 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -37,5 +37,5 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
 if __name__ == '__main__':
     import sys
     sys.stdout.write(
-        HomeServerConfig().generate_config(sys.argv[1], sys.argv[2])
+        HomeServerConfig().generate_config(sys.argv[1], sys.argv[2])[0]
     )