summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Kaye <1917473+michaelkaye@users.noreply.github.com>2018-06-22 16:54:59 +0100
committerMichael Kaye <1917473+michaelkaye@users.noreply.github.com>2018-06-22 16:54:59 +0100
commitc18011621541e6487cadfb158a9999ab268c902f (patch)
tree593f1fb20c3a9f46f7d9d564474969d8daa21879
parentMerge branch 'master' into develop (diff)
downloadsynapse-michaelkaye/synapse_config_check.tar.xz
-rw-r--r--synapse/config/_base.py11
-rw-r--r--tests/config/test_check.py72
2 files changed, 82 insertions, 1 deletions
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index b748ed2b0a..66034a386c 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -224,6 +224,11 @@ class Config(object):
             help="Generate a config file for the server name"
         )
         config_parser.add_argument(
+            "--check-config",
+            action="store_true",
+            help="Check configuration supplied is valid"
+        )
+        config_parser.add_argument(
             "--report-stats",
             action="store",
             help="Whether the generated config reports anonymized usage statistics",
@@ -250,6 +255,8 @@ class Config(object):
         config_files = find_config_files(search_paths=config_args.config_path)
 
         generate_keys = config_args.generate_keys
+        
+        check_config = config_args.check_config
 
         obj = cls()
 
@@ -333,7 +340,9 @@ class Config(object):
         if generate_keys:
             return None
 
-        obj.invoke_all("read_arguments", args)
+        obj.invoke_all("read_arguments", args)  
+        if check_config:
+            return None
 
         return obj
 
diff --git a/tests/config/test_check.py b/tests/config/test_check.py
new file mode 100644
index 0000000000..782b8d0a49
--- /dev/null
+++ b/tests/config/test_check.py
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os.path
+import shutil
+import tempfile
+import yaml
+from synapse.config.homeserver import HomeServerConfig
+from tests import unittest
+
+
+class ConfigLoadingTestCase(unittest.TestCase):
+
+    def setUp(self):
+        self.dir = tempfile.mkdtemp()
+        print(self.dir)
+        self.file = os.path.join(self.dir, "homeserver.yaml")
+
+    def tearDown(self):
+        shutil.rmtree(self.dir)
+
+    def test_load_fails_if_server_name_missing(self):
+        self.generate_config()
+        self.remove_lines_containing("server_name")
+        with self.assertRaises(Exception):
+            HomeServerConfig.load_config("", ["--check-config", "-c", self.file])
+        with self.assertRaises(Exception):
+            HomeServerConfig.load_or_generate_config("", ["--check-config", "-c", self.file])
+
+    def test_generated_config_passes_check(self):
+        self.generate_config()
+
+        config = HomeServerConfig.load_config("", ["--check-config", "-c", self.file])
+        config = HomeServerConfig.load_or_generate_config("", ["--check-config", "-c", self.file])
+
+    def test_invalid_key(self):
+        self.generate_config()
+        self.add_lines_to_config([
+            "lemurs_key: 125123",
+        ])
+        config = HomeServerConfig.load_config("", ["--check-config", "-c", self.file])
+
+    def generate_config(self):
+        HomeServerConfig.load_or_generate_config("", [
+            "--generate-config",
+            "-c", self.file,
+            "--report-stats=yes",
+            "-H", "lemurs.win"
+        ])
+
+    def remove_lines_containing(self, needle):
+        with open(self.file, "r") as f:
+            contents = f.readlines()
+        contents = [l for l in contents if needle not in l]
+        with open(self.file, "w") as f:
+            f.write("".join(contents))
+
+    def add_lines_to_config(self, lines):
+        with open(self.file, "a") as f:
+            for line in lines:
+                f.write(line + "\n")