summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-xsynapse/app/homeserver.py44
-rw-r--r--synapse/config/database.py9
2 files changed, 42 insertions, 11 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 15c454af76..a2fca2e024 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -61,6 +61,7 @@ import resource
 import subprocess
 import sqlite3
 import syweb
+import yaml
 
 logger = logging.getLogger(__name__)
 
@@ -108,14 +109,14 @@ class SynapseHomeServer(HomeServer):
             return None
 
     def build_db_pool(self):
-        return adbapi.ConnectionPool(
-            "sqlite3", self.get_db_name(),
-            check_same_thread=False,
-            cp_min=1,
-            cp_max=1,
-            cp_openfun=prepare_database,  # Prepare the database for each conn
-                                          # so that :memory: sqlite works
-        )
+        name = self.db_config.pop("name", None)
+        if name == "MySQLdb":
+            return adbapi.ConnectionPool(
+                name,
+                **self.db_config
+            )
+
+        raise RuntimeError("Unsupported database type")
 
     def create_resource_tree(self, redirect_root_to_web_client):
         """Create the resource tree for this Home Server.
@@ -357,11 +358,29 @@ def setup(config_options):
 
     tls_context_factory = context_factory.ServerContextFactory(config)
 
+    if config.database_config:
+        with open(config.database_config, 'r') as f:
+            db_config = yaml.safe_load(f)
+
+        name = db_config.get("name", None)
+        if name == "MySQLdb":
+            db_config.update({
+                "sql_mode": "TRADITIONAL",
+                "charset": "utf8",
+                "use_unicode": True,
+            })
+    else:
+        db_config = {
+            "name": "sqlite3",
+            "database": config.database_path,
+        }
+
     hs = SynapseHomeServer(
         config.server_name,
         domain_with_port=domain_with_port,
         upload_dir=os.path.abspath("uploads"),
         db_name=config.database_path,
+        db_config=db_config,
         tls_context_factory=tls_context_factory,
         config=config,
         content_addr=config.content_addr,
@@ -377,9 +396,12 @@ def setup(config_options):
     logger.info("Preparing database: %s...", db_name)
 
     try:
-        with sqlite3.connect(db_name) as db_conn:
-            prepare_sqlite3_database(db_conn)
-            prepare_database(db_conn)
+        # with sqlite3.connect(db_name) as db_conn:
+        #     prepare_sqlite3_database(db_conn)
+        #     prepare_database(db_conn)
+        import MySQLdb
+        db_conn = MySQLdb.connect(**db_config)
+        prepare_database(db_conn)
     except UpgradeDatabaseException:
         sys.stderr.write(
             "\nFailed to upgrade database.\n"
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 87efe54645..8dc9873f8c 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -26,6 +26,11 @@ class DatabaseConfig(Config):
             self.database_path = self.abspath(args.database_path)
         self.event_cache_size = self.parse_size(args.event_cache_size)
 
+        if args.database_config:
+            self.database_config = self.abspath(args.database_config)
+        else:
+            self.database_config = None
+
     @classmethod
     def add_arguments(cls, parser):
         super(DatabaseConfig, cls).add_arguments(parser)
@@ -38,6 +43,10 @@ class DatabaseConfig(Config):
             "--event-cache-size", default="100K",
             help="Number of events to cache in memory."
         )
+        db_group.add_argument(
+            "--database-config", default=None,
+            help="Location of the database configuration file."
+        )
 
     @classmethod
     def generate_config(cls, args, config_dir_path):