summary refs log tree commit diff
path: root/synapse/app/homeserver.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rwxr-xr-xsynapse/app/homeserver.py35
1 files changed, 4 insertions, 31 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 3709cd7bf9..f29f9d702e 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -350,42 +350,15 @@ def setup(config_options):
 
     tls_context_factory = context_factory.ServerContextFactory(config)
 
-    if config.database_config:
-        with open(config.database_config, 'r') as f:
-            db_config = yaml.safe_load(f)
-    else:
-        db_config = {
-            "name": "sqlite3",
-            "args": {
-                "database": config.database_path,
-            },
-        }
-
-    db_config = {
-        k: v for k, v in db_config.items()
-    }
-
-    name = db_config.get("name", None)
-    if name == "psycopg2":
-        pass
-    elif name == "sqlite3":
-        db_config.setdefault("args", {}).update({
-            "cp_min": 1,
-            "cp_max": 1,
-            "check_same_thread": False,
-        })
-    else:
-        raise RuntimeError("Unsupported database type '%s'" % (name,))
-
-    database_engine = create_engine(name)
-    db_config["args"]["cp_openfun"] = database_engine.on_new_connection
+    database_engine = create_engine(config.database_config["name"])
+    config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
 
     hs = SynapseHomeServer(
         config.server_name,
         domain_with_port=domain_with_port,
         upload_dir=os.path.abspath("uploads"),
         db_name=config.database_path,
-        db_config=db_config,
+        db_config=config.database_config,
         tls_context_factory=tls_context_factory,
         config=config,
         content_addr=config.content_addr,
@@ -404,7 +377,7 @@ def setup(config_options):
     try:
         db_conn = database_engine.module.connect(
             **{
-                k: v for k, v in db_config.get("args", {}).items()
+                k: v for k, v in config.database_config.get("args", {}).items()
                 if not k.startswith("cp_")
             }
         )