summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-xsynapse/app/homeserver.py20
-rw-r--r--synapse/config/__main__.py7
-rw-r--r--synapse/config/_base.py35
-rw-r--r--synapse/config/registration.py18
-rw-r--r--tests/config/__init__.py14
-rw-r--r--tests/config/test_generate.py50
-rw-r--r--tests/config/test_load.py77
7 files changed, 198 insertions, 23 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 0a6a19033d..89238cb7e3 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -24,6 +24,7 @@ import resource
 import subprocess
 import sys
 import time
+from synapse.config._base import ConfigError
 
 from synapse.python_dependencies import (
     check_requirements, DEPENDENCY_LINKS
@@ -350,11 +351,20 @@ def setup(config_options):
     Returns:
         HomeServer
     """
-    config = HomeServerConfig.load_config(
-        "Synapse Homeserver",
-        config_options,
-        generate_section="Homeserver"
-    )
+    try:
+        config = HomeServerConfig.load_config(
+            "Synapse Homeserver",
+            config_options,
+            generate_section="Homeserver"
+        )
+    except ConfigError as e:
+        sys.stderr.write("\n" + e.message + "\n")
+        sys.exit(1)
+
+    if not config:
+        # If a config isn't returned, and an exception isn't raised, we're just
+        # generating config files and shouldn't try to continue.
+        sys.exit(0)
 
     config.setup_logging()
 
diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py
index ea9e7907a6..0a3b70e11f 100644
--- a/synapse/config/__main__.py
+++ b/synapse/config/__main__.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from synapse.config._base import ConfigError
 
 if __name__ == "__main__":
     import sys
@@ -21,7 +22,11 @@ if __name__ == "__main__":
 
     if action == "read":
         key = sys.argv[2]
-        config = HomeServerConfig.load_config("", sys.argv[3:])
+        try:
+            config = HomeServerConfig.load_config("", sys.argv[3:])
+        except ConfigError as e:
+            sys.stderr.write("\n" + e.message + "\n")
+            sys.exit(1)
 
         print getattr(config, key)
         sys.exit(0)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index a9304a11ba..15d78ff33a 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -17,7 +17,6 @@ import argparse
 import errno
 import os
 import yaml
-import sys
 from textwrap import dedent
 
 
@@ -136,13 +135,20 @@ class Config(object):
                 results.append(getattr(cls, name)(self, *args, **kargs))
         return results
 
-    def generate_config(self, config_dir_path, server_name, report_stats=None):
+    def generate_config(
+            self,
+            config_dir_path,
+            server_name,
+            is_generating_file,
+            report_stats=None,
+    ):
         default_config = "# vim:ft=yaml\n"
 
         default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
             "default_config",
             config_dir_path=config_dir_path,
             server_name=server_name,
+            is_generating_file=is_generating_file,
             report_stats=report_stats,
         ))
 
@@ -244,8 +250,10 @@ class Config(object):
 
                 server_name = config_args.server_name
                 if not server_name:
-                    print "Must specify a server_name to a generate config for."
-                    sys.exit(1)
+                    raise ConfigError(
+                        "Must specify a server_name to a generate config for."
+                        " Pass -H server.name."
+                    )
                 if not os.path.exists(config_dir_path):
                     os.makedirs(config_dir_path)
                 with open(config_path, "wb") as config_file:
@@ -253,6 +261,7 @@ class Config(object):
                         config_dir_path=config_dir_path,
                         server_name=server_name,
                         report_stats=(config_args.report_stats == "yes"),
+                        is_generating_file=True
                     )
                     obj.invoke_all("generate_files", config)
                     config_file.write(config_bytes)
@@ -266,7 +275,7 @@ class Config(object):
                     "If this server name is incorrect, you will need to"
                     " regenerate the SSL certificates"
                 )
-                sys.exit(0)
+                return
             else:
                 print (
                     "Config file %r already exists. Generating any missing key"
@@ -302,25 +311,25 @@ class Config(object):
             specified_config.update(yaml_config)
 
         if "server_name" not in specified_config:
-            sys.stderr.write("\n" + MISSING_SERVER_NAME + "\n")
-            sys.exit(1)
+            raise ConfigError(MISSING_SERVER_NAME)
 
         server_name = specified_config["server_name"]
         _, config = obj.generate_config(
             config_dir_path=config_dir_path,
-            server_name=server_name
+            server_name=server_name,
+            is_generating_file=False,
         )
         config.pop("log_config")
         config.update(specified_config)
         if "report_stats" not in config:
-            sys.stderr.write(
-                "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
-                MISSING_REPORT_STATS_SPIEL + "\n")
-            sys.exit(1)
+            raise ConfigError(
+                MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
+                MISSING_REPORT_STATS_SPIEL
+            )
 
         if generate_keys:
             obj.invoke_all("generate_files", config)
-            sys.exit(0)
+            return
 
         obj.invoke_all("read_config", config)
 
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 90ea19bd4b..9b6dacc5b8 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -33,12 +33,24 @@ class RegistrationConfig(Config):
 
         self.registration_shared_secret = config.get("registration_shared_secret")
         self.macaroon_secret_key = config.get("macaroon_secret_key")
+        if self.macaroon_secret_key is None:
+            raise Exception(
+                "Config is missing missing macaroon_secret_key - please set it"
+                " in your config file."
+            )
         self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
         self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
         self.allow_guest_access = config.get("allow_guest_access", False)
 
-    def default_config(self, **kwargs):
+    def default_config(self, is_generating_file=False, **kwargs):
         registration_shared_secret = random_string_with_symbols(50)
+
+        macaroon_line = ""
+        if is_generating_file:
+            macaroon_line += '\n        macaroon_secret_key: "%s"\n' % (
+                random_string_with_symbols(50),
+            )
+
         macaroon_secret_key = random_string_with_symbols(50)
         return """\
         ## Registration ##
@@ -49,9 +61,7 @@ class RegistrationConfig(Config):
         # If set, allows registration by anyone who also has the shared
         # secret, even if registration is otherwise disabled.
         registration_shared_secret: "%(registration_shared_secret)s"
-
-        macaroon_secret_key: "%(macaroon_secret_key)s"
-
+%(macaroon_line)s
         # Set the number of bcrypt rounds used to generate password hash.
         # Larger numbers increase the work factor needed to generate the hash.
         # The default number of rounds is 12.
diff --git a/tests/config/__init__.py b/tests/config/__init__.py
new file mode 100644
index 0000000000..b7df13c9ee
--- /dev/null
+++ b/tests/config/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
new file mode 100644
index 0000000000..4329d73974
--- /dev/null
+++ b/tests/config/test_generate.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os.path
+import shutil
+import tempfile
+from synapse.config.homeserver import HomeServerConfig
+from tests import unittest
+
+
+class ConfigGenerationTestCase(unittest.TestCase):
+
+    def setUp(self):
+        self.dir = tempfile.mkdtemp()
+        print self.dir
+        self.file = os.path.join(self.dir, "homeserver.yaml")
+
+    def tearDown(self):
+        shutil.rmtree(self.dir)
+
+    def test_generate_config_generates_files(self):
+        HomeServerConfig.load_config("", [
+            "--generate-config",
+            "-c", self.file,
+            "--report-stats=yes",
+            "-H", "lemurs.win"
+        ])
+
+        self.assertSetEqual(
+            set([
+                "homeserver.yaml",
+                "lemurs.win.log.config",
+                "lemurs.win.signing.key",
+                "lemurs.win.tls.crt",
+                "lemurs.win.tls.dh",
+                "lemurs.win.tls.key",
+            ]),
+            set(os.listdir(self.dir))
+        )
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
new file mode 100644
index 0000000000..7f41279715
--- /dev/null
+++ b/tests/config/test_load.py
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os.path
+import shutil
+import tempfile
+import yaml
+from synapse.config.homeserver import HomeServerConfig
+from tests import unittest
+
+
+class ConfigLoadingTestCase(unittest.TestCase):
+
+    def setUp(self):
+        self.dir = tempfile.mkdtemp()
+        print self.dir
+        self.file = os.path.join(self.dir, "homeserver.yaml")
+
+    def tearDown(self):
+        shutil.rmtree(self.dir)
+
+    def test_load_fails_if_server_name_missing(self):
+        self.generate_config_and_remove_lines_containing("server_name")
+        with self.assertRaises(Exception):
+            HomeServerConfig.load_config("", ["-c", self.file])
+
+    def test_generates_and_loads_macaroon_secret_key(self):
+        self.generate_config()
+
+        with open(self.file,
+                  "r") as f:
+            raw = yaml.load(f)
+        self.assertIn("macaroon_secret_key", raw)
+
+        config = HomeServerConfig.load_config("", ["-c", self.file])
+        self.assertTrue(
+            hasattr(config, "macaroon_secret_key"),
+            "Want config to have attr macaroon_secret_key"
+        )
+        if len(config.macaroon_secret_key) < 5:
+            self.fail(
+                "Want macaroon secret key to be string of at least length 5,"
+                "was: %r" % (config.macaroon_secret_key,)
+            )
+
+    def test_load_fails_if_macaroon_secret_key_missing(self):
+        self.generate_config_and_remove_lines_containing("macaroon")
+        with self.assertRaises(Exception):
+            HomeServerConfig.load_config("", ["-c", self.file])
+
+    def generate_config(self):
+        HomeServerConfig.load_config("", [
+            "--generate-config",
+            "-c", self.file,
+            "--report-stats=yes",
+            "-H", "lemurs.win"
+        ])
+
+    def generate_config_and_remove_lines_containing(self, needle):
+        self.generate_config()
+
+        with open(self.file, "r") as f:
+            contents = f.readlines()
+        contents = [l for l in contents if needle not in l]
+        with open(self.file, "w") as f:
+            f.write("".join(contents))