summary refs log tree commit diff
path: root/synapse/config/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/config/database.py')
-rw-r--r--synapse/config/database.py76
1 files changed, 55 insertions, 21 deletions
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 87efe54645..f0611e8884 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -14,32 +14,66 @@
 # limitations under the License.
 
 from ._base import Config
-import os
 
 
 class DatabaseConfig(Config):
-    def __init__(self, args):
-        super(DatabaseConfig, self).__init__(args)
-        if args.database_path == ":memory:":
-            self.database_path = ":memory:"
+
+    def read_config(self, config):
+        self.event_cache_size = self.parse_size(
+            config.get("event_cache_size", "10K")
+        )
+
+        self.database_config = config.get("database")
+
+        if self.database_config is None:
+            self.database_config = {
+                "name": "sqlite3",
+                "args": {},
+            }
+
+        name = self.database_config.get("name", None)
+        if name == "psycopg2":
+            pass
+        elif name == "sqlite3":
+            self.database_config.setdefault("args", {}).update({
+                "cp_min": 1,
+                "cp_max": 1,
+                "check_same_thread": False,
+            })
         else:
-            self.database_path = self.abspath(args.database_path)
-        self.event_cache_size = self.parse_size(args.event_cache_size)
+            raise RuntimeError("Unsupported database type '%s'" % (name,))
+
+        self.set_databasepath(config.get("database_path"))
+
+    def default_config(self, config, config_dir_path):
+        database_path = self.abspath("homeserver.db")
+        return """\
+        # Database configuration
+        database:
+          # The database engine name
+          name: "sqlite3"
+          # Arguments to pass to the engine
+          args:
+            # Path to the database
+            database: "%(database_path)s"
 
-    @classmethod
-    def add_arguments(cls, parser):
-        super(DatabaseConfig, cls).add_arguments(parser)
+        # Number of events to cache in memory.
+        event_cache_size: "10K"
+        """ % locals()
+
+    def read_arguments(self, args):
+        self.set_databasepath(args.database_path)
+
+    def set_databasepath(self, database_path):
+        if database_path != ":memory:":
+            database_path = self.abspath(database_path)
+        if self.database_config.get("name", None) == "sqlite3":
+            if database_path is not None:
+                self.database_config["args"]["database"] = database_path
+
+    def add_arguments(self, parser):
         db_group = parser.add_argument_group("database")
         db_group.add_argument(
-            "-d", "--database-path", default="homeserver.db",
-            help="The database name."
+            "-d", "--database-path", metavar="SQLITE_DATABASE_PATH",
+            help="The path to a sqlite database to use."
         )
-        db_group.add_argument(
-            "--event-cache-size", default="100K",
-            help="Number of events to cache in memory."
-        )
-
-    @classmethod
-    def generate_config(cls, args, config_dir_path):
-        super(DatabaseConfig, cls).generate_config(args, config_dir_path)
-        args.database_path = os.path.abspath(args.database_path)