summary refs log tree commit diff
path: root/synapse/app/homeserver.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/app/homeserver.py')
-rwxr-xr-xsynapse/app/homeserver.py42
1 files changed, 26 insertions, 16 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 541059b209..694a0125ad 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -17,9 +17,9 @@
 import sys
 sys.dont_write_bytecode = True
 
+from synapse.storage.engines import create_engine
 from synapse.storage import (
-    prepare_database, prepare_sqlite3_database, are_all_users_on_domain,
-    UpgradeDatabaseException,
+    are_all_users_on_domain, UpgradeDatabaseException,
 )
 
 from synapse.server import HomeServer
@@ -58,9 +58,9 @@ import os
 import re
 import resource
 import subprocess
-import sqlite3
 
-logger = logging.getLogger(__name__)
+
+logger = logging.getLogger("synapse.app.homeserver")
 
 
 class SynapseHomeServer(HomeServer):
@@ -104,13 +104,11 @@ class SynapseHomeServer(HomeServer):
             return None
 
     def build_db_pool(self):
+        name = self.db_config["name"]
+
         return adbapi.ConnectionPool(
-            "sqlite3", self.get_db_name(),
-            check_same_thread=False,
-            cp_min=1,
-            cp_max=1,
-            cp_openfun=prepare_database,  # Prepare the database for each conn
-                                          # so that :memory: sqlite works
+            name,
+            **self.db_config.get("args", {})
         )
 
     def create_resource_tree(self, redirect_root_to_web_client):
@@ -242,9 +240,9 @@ class SynapseHomeServer(HomeServer):
             )
             logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
 
-    def run_startup_checks(self, db_conn):
+    def run_startup_checks(self, db_conn, database_engine):
         all_users_native = are_all_users_on_domain(
-            db_conn, self.hostname
+            db_conn.cursor(), database_engine, self.hostname
         )
         if not all_users_native:
             sys.stderr.write(
@@ -368,15 +366,20 @@ def setup(config_options):
 
     tls_context_factory = context_factory.ServerContextFactory(config)
 
+    database_engine = create_engine(config.database_config["name"])
+    config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
+
     hs = SynapseHomeServer(
         config.server_name,
         domain_with_port=domain_with_port,
         upload_dir=os.path.abspath("uploads"),
         db_name=config.database_path,
+        db_config=config.database_config,
         tls_context_factory=tls_context_factory,
         config=config,
         content_addr=config.content_addr,
         version_string=version_string,
+        database_engine=database_engine,
     )
 
     hs.create_resource_tree(
@@ -388,10 +391,17 @@ def setup(config_options):
     logger.info("Preparing database: %s...", db_name)
 
     try:
-        with sqlite3.connect(db_name) as db_conn:
-            prepare_sqlite3_database(db_conn)
-            prepare_database(db_conn)
-            hs.run_startup_checks(db_conn)
+        db_conn = database_engine.module.connect(
+            **{
+                k: v for k, v in config.database_config.get("args", {}).items()
+                if not k.startswith("cp_")
+            }
+        )
+
+        database_engine.prepare_database(db_conn)
+        hs.run_startup_checks(db_conn, database_engine)
+
+        db_conn.commit()
     except UpgradeDatabaseException:
         sys.stderr.write(
             "\nFailed to upgrade database.\n"