summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-xsynapse/app/homeserver.py3
-rw-r--r--synapse/config/_base.py147
-rw-r--r--tests/config/test_generate.py2
-rw-r--r--tests/config/test_load.py22
4 files changed, 126 insertions, 48 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 22e1721fc4..40ffd9bf0d 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -266,10 +266,9 @@ def setup(config_options):
         HomeServer
     """
     try:
-        config = HomeServerConfig.load_config(
+        config = HomeServerConfig.load_or_generate_config(
             "Synapse Homeserver",
             config_options,
-            generate_section="Homeserver"
         )
     except ConfigError as e:
         sys.stderr.write("\n" + e.message + "\n")
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 7449f36491..af9f17bf7b 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -157,9 +157,40 @@ class Config(object):
         return default_config, config
 
     @classmethod
-    def load_config(cls, description, argv, generate_section=None):
+    def load_config(cls, description, argv):
+        config_parser = argparse.ArgumentParser(
+            description=description,
+        )
+        config_parser.add_argument(
+            "-c", "--config-path",
+            action="append",
+            metavar="CONFIG_FILE",
+            help="Specify config file. Can be given multiple times and"
+                 " may specify directories containing *.yaml files."
+        )
+
+        config_parser.add_argument(
+            "--keys-directory",
+            metavar="DIRECTORY",
+            help="Where files such as certs and signing keys are stored when"
+                 " their location is given explicitly in the config."
+                 " Defaults to the directory containing the last config file",
+        )
+
+        config_args = config_parser.parse_args(argv)
+
+        config_files = find_config_files(search_paths=config_args.config_path)
+
         obj = cls()
+        obj.read_config_files(
+            config_files,
+            keys_directory=config_args.keys_directory,
+            generate_keys=False,
+        )
+        return obj
 
+    @classmethod
+    def load_or_generate_config(cls, description, argv):
         config_parser = argparse.ArgumentParser(add_help=False)
         config_parser.add_argument(
             "-c", "--config-path",
@@ -176,7 +207,7 @@ class Config(object):
         config_parser.add_argument(
             "--report-stats",
             action="store",
-            help="Stuff",
+            help="Whether the generated config reports anonymized usage statistics",
             choices=["yes", "no"]
         )
         config_parser.add_argument(
@@ -197,36 +228,11 @@ class Config(object):
         )
         config_args, remaining_args = config_parser.parse_known_args(argv)
 
+        config_files = find_config_files(search_paths=config_args.config_path)
+
         generate_keys = config_args.generate_keys
 
-        config_files = []
-        if config_args.config_path:
-            for config_path in config_args.config_path:
-                if os.path.isdir(config_path):
-                    # We accept specifying directories as config paths, we search
-                    # inside that directory for all files matching *.yaml, and then
-                    # we apply them in *sorted* order.
-                    files = []
-                    for entry in os.listdir(config_path):
-                        entry_path = os.path.join(config_path, entry)
-                        if not os.path.isfile(entry_path):
-                            print (
-                                "Found subdirectory in config directory: %r. IGNORING."
-                            ) % (entry_path, )
-                            continue
-
-                        if not entry.endswith(".yaml"):
-                            print (
-                                "Found file in config directory that does not"
-                                " end in '.yaml': %r. IGNORING."
-                            ) % (entry_path, )
-                            continue
-
-                        files.append(entry_path)
-
-                    config_files.extend(sorted(files))
-                else:
-                    config_files.append(config_path)
+        obj = cls()
 
         if config_args.generate_config:
             if config_args.report_stats is None:
@@ -299,28 +305,43 @@ class Config(object):
                 " -c CONFIG-FILE\""
             )
 
-        if config_args.keys_directory:
-            config_dir_path = config_args.keys_directory
-        else:
-            config_dir_path = os.path.dirname(config_args.config_path[-1])
-        config_dir_path = os.path.abspath(config_dir_path)
+        obj.read_config_files(
+            config_files,
+            keys_directory=config_args.keys_directory,
+            generate_keys=generate_keys,
+        )
+
+        if generate_keys:
+            return None
+
+        obj.invoke_all("read_arguments", args)
+
+        return obj
+
+    def read_config_files(self, config_files, keys_directory=None,
+                          generate_keys=False):
+        if not keys_directory:
+            keys_directory = os.path.dirname(config_files[-1])
+
+        config_dir_path = os.path.abspath(keys_directory)
 
         specified_config = {}
         for config_file in config_files:
-            yaml_config = cls.read_config_file(config_file)
+            yaml_config = self.read_config_file(config_file)
             specified_config.update(yaml_config)
 
         if "server_name" not in specified_config:
             raise ConfigError(MISSING_SERVER_NAME)
 
         server_name = specified_config["server_name"]
-        _, config = obj.generate_config(
+        _, config = self.generate_config(
             config_dir_path=config_dir_path,
             server_name=server_name,
             is_generating_file=False,
         )
         config.pop("log_config")
         config.update(specified_config)
+
         if "report_stats" not in config:
             raise ConfigError(
                 MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
@@ -328,11 +349,51 @@ class Config(object):
             )
 
         if generate_keys:
-            obj.invoke_all("generate_files", config)
+            self.invoke_all("generate_files", config)
             return
 
-        obj.invoke_all("read_config", config)
-
-        obj.invoke_all("read_arguments", args)
-
-        return obj
+        self.invoke_all("read_config", config)
+
+
+def find_config_files(search_paths):
+    """Finds config files using a list of search paths. If a path is a file
+    then that file path is added to the list. If a search path is a directory
+    then all the "*.yaml" files in that directory are added to the list in
+    sorted order.
+
+    Args:
+        search_paths(list(str)): A list of paths to search.
+
+    Returns:
+        list(str): A list of file paths.
+    """
+
+    config_files = []
+    if search_paths:
+        for config_path in search_paths:
+            if os.path.isdir(config_path):
+                # We accept specifying directories as config paths, we search
+                # inside that directory for all files matching *.yaml, and then
+                # we apply them in *sorted* order.
+                files = []
+                for entry in os.listdir(config_path):
+                    entry_path = os.path.join(config_path, entry)
+                    if not os.path.isfile(entry_path):
+                        print (
+                            "Found subdirectory in config directory: %r. IGNORING."
+                        ) % (entry_path, )
+                        continue
+
+                    if not entry.endswith(".yaml"):
+                        print (
+                            "Found file in config directory that does not"
+                            " end in '.yaml': %r. IGNORING."
+                        ) % (entry_path, )
+                        continue
+
+                    files.append(entry_path)
+
+                config_files.extend(sorted(files))
+            else:
+                config_files.append(config_path)
+    return config_files
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index 4329d73974..8f57fbeb23 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -30,7 +30,7 @@ class ConfigGenerationTestCase(unittest.TestCase):
         shutil.rmtree(self.dir)
 
     def test_generate_config_generates_files(self):
-        HomeServerConfig.load_config("", [
+        HomeServerConfig.load_or_generate_config("", [
             "--generate-config",
             "-c", self.file,
             "--report-stats=yes",
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index bf46233c5c..161a87d7e3 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -34,6 +34,8 @@ class ConfigLoadingTestCase(unittest.TestCase):
         self.generate_config_and_remove_lines_containing("server_name")
         with self.assertRaises(Exception):
             HomeServerConfig.load_config("", ["-c", self.file])
+        with self.assertRaises(Exception):
+            HomeServerConfig.load_or_generate_config("", ["-c", self.file])
 
     def test_generates_and_loads_macaroon_secret_key(self):
         self.generate_config()
@@ -54,11 +56,24 @@ class ConfigLoadingTestCase(unittest.TestCase):
                 "was: %r" % (config.macaroon_secret_key,)
             )
 
+        config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
+        self.assertTrue(
+            hasattr(config, "macaroon_secret_key"),
+            "Want config to have attr macaroon_secret_key"
+        )
+        if len(config.macaroon_secret_key) < 5:
+            self.fail(
+                "Want macaroon secret key to be string of at least length 5,"
+                "was: %r" % (config.macaroon_secret_key,)
+            )
+
     def test_load_succeeds_if_macaroon_secret_key_missing(self):
         self.generate_config_and_remove_lines_containing("macaroon")
         config1 = HomeServerConfig.load_config("", ["-c", self.file])
         config2 = HomeServerConfig.load_config("", ["-c", self.file])
+        config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
         self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key)
+        self.assertEqual(config1.macaroon_secret_key, config3.macaroon_secret_key)
 
     def test_disable_registration(self):
         self.generate_config()
@@ -70,14 +85,17 @@ class ConfigLoadingTestCase(unittest.TestCase):
         config = HomeServerConfig.load_config("", ["-c", self.file])
         self.assertFalse(config.enable_registration)
 
+        config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
+        self.assertFalse(config.enable_registration)
+
         # Check that either config value is clobbered by the command line.
-        config = HomeServerConfig.load_config("", [
+        config = HomeServerConfig.load_or_generate_config("", [
             "-c", self.file, "--enable-registration"
         ])
         self.assertTrue(config.enable_registration)
 
     def generate_config(self):
-        HomeServerConfig.load_config("", [
+        HomeServerConfig.load_or_generate_config("", [
             "--generate-config",
             "-c", self.file,
             "--report-stats=yes",