summary refs log tree commit diff
path: root/synapse/config/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/config/_base.py')
-rw-r--r--synapse/config/_base.py35
1 files changed, 22 insertions, 13 deletions
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index a9304a11ba..15d78ff33a 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -17,7 +17,6 @@ import argparse
 import errno
 import os
 import yaml
-import sys
 from textwrap import dedent
 
 
@@ -136,13 +135,20 @@ class Config(object):
                 results.append(getattr(cls, name)(self, *args, **kargs))
         return results
 
-    def generate_config(self, config_dir_path, server_name, report_stats=None):
+    def generate_config(
+            self,
+            config_dir_path,
+            server_name,
+            is_generating_file,
+            report_stats=None,
+    ):
         default_config = "# vim:ft=yaml\n"
 
         default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
             "default_config",
             config_dir_path=config_dir_path,
             server_name=server_name,
+            is_generating_file=is_generating_file,
             report_stats=report_stats,
         ))
 
@@ -244,8 +250,10 @@ class Config(object):
 
                 server_name = config_args.server_name
                 if not server_name:
-                    print "Must specify a server_name to a generate config for."
-                    sys.exit(1)
+                    raise ConfigError(
+                        "Must specify a server_name to a generate config for."
+                        " Pass -H server.name."
+                    )
                 if not os.path.exists(config_dir_path):
                     os.makedirs(config_dir_path)
                 with open(config_path, "wb") as config_file:
@@ -253,6 +261,7 @@ class Config(object):
                         config_dir_path=config_dir_path,
                         server_name=server_name,
                         report_stats=(config_args.report_stats == "yes"),
+                        is_generating_file=True
                     )
                     obj.invoke_all("generate_files", config)
                     config_file.write(config_bytes)
@@ -266,7 +275,7 @@ class Config(object):
                     "If this server name is incorrect, you will need to"
                     " regenerate the SSL certificates"
                 )
-                sys.exit(0)
+                return
             else:
                 print (
                     "Config file %r already exists. Generating any missing key"
@@ -302,25 +311,25 @@ class Config(object):
             specified_config.update(yaml_config)
 
         if "server_name" not in specified_config:
-            sys.stderr.write("\n" + MISSING_SERVER_NAME + "\n")
-            sys.exit(1)
+            raise ConfigError(MISSING_SERVER_NAME)
 
         server_name = specified_config["server_name"]
         _, config = obj.generate_config(
             config_dir_path=config_dir_path,
-            server_name=server_name
+            server_name=server_name,
+            is_generating_file=False,
         )
         config.pop("log_config")
         config.update(specified_config)
         if "report_stats" not in config:
-            sys.stderr.write(
-                "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
-                MISSING_REPORT_STATS_SPIEL + "\n")
-            sys.exit(1)
+            raise ConfigError(
+                MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
+                MISSING_REPORT_STATS_SPIEL
+            )
 
         if generate_keys:
             obj.invoke_all("generate_files", config)
-            sys.exit(0)
+            return
 
         obj.invoke_all("read_config", config)