summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rwxr-xr-xsynapse/app/homeserver.py35
-rw-r--r--synapse/config/database.py23
2 files changed, 25 insertions, 33 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 3709cd7bf9..f29f9d702e 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -350,42 +350,15 @@ def setup(config_options):
 
     tls_context_factory = context_factory.ServerContextFactory(config)
 
-    if config.database_config:
-        with open(config.database_config, 'r') as f:
-            db_config = yaml.safe_load(f)
-    else:
-        db_config = {
-            "name": "sqlite3",
-            "args": {
-                "database": config.database_path,
-            },
-        }
-
-    db_config = {
-        k: v for k, v in db_config.items()
-    }
-
-    name = db_config.get("name", None)
-    if name == "psycopg2":
-        pass
-    elif name == "sqlite3":
-        db_config.setdefault("args", {}).update({
-            "cp_min": 1,
-            "cp_max": 1,
-            "check_same_thread": False,
-        })
-    else:
-        raise RuntimeError("Unsupported database type '%s'" % (name,))
-
-    database_engine = create_engine(name)
-    db_config["args"]["cp_openfun"] = database_engine.on_new_connection
+    database_engine = create_engine(config.database_config["name"])
+    config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
 
     hs = SynapseHomeServer(
         config.server_name,
         domain_with_port=domain_with_port,
         upload_dir=os.path.abspath("uploads"),
         db_name=config.database_path,
-        db_config=db_config,
+        db_config=config.database_config,
         tls_context_factory=tls_context_factory,
         config=config,
         content_addr=config.content_addr,
@@ -404,7 +377,7 @@ def setup(config_options):
     try:
         db_conn = database_engine.module.connect(
             **{
-                k: v for k, v in db_config.get("args", {}).items()
+                k: v for k, v in config.database_config.get("args", {}).items()
                 if not k.startswith("cp_")
             }
         )
diff --git a/synapse/config/database.py b/synapse/config/database.py
index f3d0898c09..190d119df4 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -15,6 +15,7 @@
 
 from ._base import Config
 import os
+import yaml
 
 
 class DatabaseConfig(Config):
@@ -27,9 +28,27 @@ class DatabaseConfig(Config):
         self.event_cache_size = self.parse_size(args.event_cache_size)
 
         if args.database_config:
-            self.database_config = self.abspath(args.database_config)
+            with open(args.database_config) as f:
+                self.database_config = yaml.safe_load(f)
         else:
-            self.database_config = None
+            self.database_config = {
+                "name": "sqlite3",
+                "args": {
+                    "database": self.database_path,
+                },
+            }
+
+        name = self.database_config.get("name", None)
+        if name == "psycopg2":
+            pass
+        elif name == "sqlite3":
+            self.database_config.setdefault("args", {}).update({
+                "cp_min": 1,
+                "cp_max": 1,
+                "check_same_thread": False,
+            })
+        else:
+            raise RuntimeError("Unsupported database type '%s'" % (name,))
 
     @classmethod
     def add_arguments(cls, parser):