summary refs log tree commit diff
path: root/synapse/config/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/config/_base.py')
-rw-r--r--synapse/config/_base.py184
1 files changed, 113 insertions, 71 deletions
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index b59f4e45e2..cd4bd28e8c 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -14,9 +14,10 @@
 # limitations under the License.
 
 import argparse
-import sys
 import os
 import yaml
+import sys
+from textwrap import dedent
 
 
 class ConfigError(Exception):
@@ -24,18 +25,35 @@ class ConfigError(Exception):
 
 
 class Config(object):
-    def __init__(self, args):
-        pass
 
     @staticmethod
-    def parse_size(string):
+    def parse_size(value):
+        if isinstance(value, int) or isinstance(value, long):
+            return value
         sizes = {"K": 1024, "M": 1024 * 1024}
         size = 1
-        suffix = string[-1]
+        suffix = value[-1]
         if suffix in sizes:
-            string = string[:-1]
+            value = value[:-1]
             size = sizes[suffix]
-        return int(string) * size
+        return int(value) * size
+
+    @staticmethod
+    def parse_duration(value):
+        if isinstance(value, int) or isinstance(value, long):
+            return value
+        second = 1000
+        hour = 60 * 60 * second
+        day = 24 * hour
+        week = 7 * day
+        year = 365 * day
+        sizes = {"s": second, "h": hour, "d": day, "w": week, "y": year}
+        size = 1
+        suffix = value[-1]
+        if suffix in sizes:
+            value = value[:-1]
+            size = sizes[suffix]
+        return int(value) * size
 
     @staticmethod
     def abspath(file_path):
@@ -77,17 +95,6 @@ class Config(object):
         with open(file_path) as file_stream:
             return file_stream.read()
 
-    @classmethod
-    def read_yaml_file(cls, file_path, config_name):
-        cls.check_file(file_path, config_name)
-        with open(file_path) as file_stream:
-            try:
-                return yaml.load(file_stream)
-            except:
-                raise ConfigError(
-                    "Error parsing yaml in file %r" % (file_path,)
-                )
-
     @staticmethod
     def default_path(name):
         return os.path.abspath(os.path.join(os.path.curdir, name))
@@ -97,84 +104,119 @@ class Config(object):
         with open(file_path) as file_stream:
             return yaml.load(file_stream)
 
-    @classmethod
-    def add_arguments(cls, parser):
-        pass
+    def invoke_all(self, name, *args, **kargs):
+        results = []
+        for cls in type(self).mro():
+            if name in cls.__dict__:
+                results.append(getattr(cls, name)(self, *args, **kargs))
+        return results
 
-    @classmethod
-    def generate_config(cls, args, config_dir_path):
-        pass
+    def generate_config(self, config_dir_path, server_name):
+        default_config = "# vim:ft=yaml\n"
+
+        default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
+            "default_config", config_dir_path, server_name
+        ))
+
+        config = yaml.load(default_config)
+
+        return default_config, config
 
     @classmethod
     def load_config(cls, description, argv, generate_section=None):
+        obj = cls()
+
         config_parser = argparse.ArgumentParser(add_help=False)
         config_parser.add_argument(
             "-c", "--config-path",
+            action="append",
             metavar="CONFIG_FILE",
             help="Specify config file"
         )
         config_parser.add_argument(
             "--generate-config",
             action="store_true",
-            help="Generate config file"
+            help="Generate a config file for the server name"
+        )
+        config_parser.add_argument(
+            "-H", "--server-name",
+            help="The server name to generate a config file for"
         )
         config_args, remaining_args = config_parser.parse_known_args(argv)
 
+        if not config_args.config_path:
+            config_parser.error(
+                "Must supply a config file.\nA config file can be automatically"
+                " generated using \"--generate-config -h SERVER_NAME"
+                " -c CONFIG-FILE\""
+            )
+
+        config_dir_path = os.path.dirname(config_args.config_path[0])
+        config_dir_path = os.path.abspath(config_dir_path)
         if config_args.generate_config:
-            if not config_args.config_path:
-                config_parser.error(
-                    "Must specify where to generate the config file"
+            server_name = config_args.server_name
+            if not server_name:
+                print "Most specify a server_name to a generate config for."
+                sys.exit(1)
+            (config_path,) = config_args.config_path
+            if not os.path.exists(config_dir_path):
+                os.makedirs(config_dir_path)
+            if os.path.exists(config_path):
+                print "Config file %r already exists" % (config_path,)
+                yaml_config = cls.read_config_file(config_path)
+                yaml_name = yaml_config["server_name"]
+                if server_name != yaml_name:
+                    print (
+                        "Config file %r has a different server_name: "
+                        " %r != %r" % (config_path, server_name, yaml_name)
+                    )
+                    sys.exit(1)
+                config_bytes, config = obj.generate_config(
+                    config_dir_path, server_name
+                )
+                config.update(yaml_config)
+                print "Generating any missing keys for %r" % (server_name,)
+                obj.invoke_all("generate_files", config)
+                sys.exit(0)
+            with open(config_path, "wb") as config_file:
+                config_bytes, config = obj.generate_config(
+                    config_dir_path, server_name
                 )
-            config_dir_path = os.path.dirname(config_args.config_path)
-            if os.path.exists(config_args.config_path):
-                defaults = cls.read_config_file(config_args.config_path)
-            else:
-                defaults = {}
-        else:
-            if config_args.config_path:
-                defaults = cls.read_config_file(config_args.config_path)
-            else:
-                defaults = {}
+                obj.invoke_all("generate_files", config)
+                config_file.write(config_bytes)
+                print (
+                    "A config file has been generated in %s for server name"
+                    " '%s' with corresponding SSL keys and self-signed"
+                    " certificates. Please review this file and customise it to"
+                    " your needs."
+                ) % (config_path, server_name)
+            print (
+                "If this server name is incorrect, you will need to regenerate"
+                " the SSL certificates"
+            )
+            sys.exit(0)
+
+        specified_config = {}
+        for config_path in config_args.config_path:
+            yaml_config = cls.read_config_file(config_path)
+            specified_config.update(yaml_config)
+
+        server_name = specified_config["server_name"]
+        _, config = obj.generate_config(config_dir_path, server_name)
+        config.pop("log_config")
+        config.update(specified_config)
+
+        obj.invoke_all("read_config", config)
 
         parser = argparse.ArgumentParser(
             parents=[config_parser],
             description=description,
             formatter_class=argparse.RawDescriptionHelpFormatter,
         )
-        cls.add_arguments(parser)
-        parser.set_defaults(**defaults)
 
+        obj.invoke_all("add_arguments", parser)
         args = parser.parse_args(remaining_args)
 
-        if config_args.generate_config:
-            config_dir_path = os.path.dirname(config_args.config_path)
-            config_dir_path = os.path.abspath(config_dir_path)
-            if not os.path.exists(config_dir_path):
-                os.makedirs(config_dir_path)
-            cls.generate_config(args, config_dir_path)
-            config = {}
-            for key, value in vars(args).items():
-                if (key not in set(["config_path", "generate_config"])
-                        and value is not None):
-                    config[key] = value
-            with open(config_args.config_path, "w") as config_file:
-                # TODO(mark/paul) We might want to output emacs-style mode
-                # markers as well as vim-style mode markers into the file,
-                # to further hint to people this is a YAML file.
-                config_file.write("# vim:ft=yaml\n")
-                yaml.dump(config, config_file, default_flow_style=False)
-            print (
-                "A config file has been generated in %s for server name"
-                " '%s' with corresponding SSL keys and self-signed"
-                " certificates. Please review this file and customise it to"
-                " your needs."
-            ) % (
-                config_args.config_path, config['server_name']
-            )
-            print (
-                "If this server name is incorrect, you will need to regenerate"
-                " the SSL certificates"
-            )
-            sys.exit(0)
+        obj.invoke_all("read_arguments", args)
 
-        return cls(args)
+        return obj