summary refs log tree commit diff
path: root/synapse/config
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/config')
-rw-r--r--synapse/config/database.py78
1 files changed, 62 insertions, 16 deletions
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 0e2509f0b1..5f2f3c7cfd 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -12,12 +12,43 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 from textwrap import indent
+from typing import List
 
 import yaml
 
-from ._base import Config
+from synapse.config._base import Config, ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+class DatabaseConnectionConfig:
+    """Contains the connection config for a particular database.
+
+    Args:
+        name: A label for the database, used for logging.
+        db_config: The config for a particular database, as per `database`
+            section of main config. Has two fields: `name` for database
+            module name, and `args` for the args to give to the database
+            connector.
+        data_stores: The list of data stores that should be provisioned on the
+            database.
+    """
+
+    def __init__(self, name: str, db_config: dict, data_stores: List[str]):
+        if db_config["name"] not in ("sqlite3", "psycopg2"):
+            raise ConfigError("Unsupported database type %r" % (db_config["name"],))
+
+        if db_config["name"] == "sqlite3":
+            db_config.setdefault("args", {}).update(
+                {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
+            )
+
+        self.name = name
+        self.config = db_config
+        self.data_stores = data_stores
 
 
 class DatabaseConfig(Config):
@@ -26,20 +57,14 @@ class DatabaseConfig(Config):
     def read_config(self, config, **kwargs):
         self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
 
-        self.database_config = config.get("database")
+        database_config = config.get("database")
 
-        if self.database_config is None:
-            self.database_config = {"name": "sqlite3", "args": {}}
+        if database_config is None:
+            database_config = {"name": "sqlite3", "args": {}}
 
-        name = self.database_config.get("name", None)
-        if name == "psycopg2":
-            pass
-        elif name == "sqlite3":
-            self.database_config.setdefault("args", {}).update(
-                {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
-            )
-        else:
-            raise RuntimeError("Unsupported database type '%s'" % (name,))
+        self.databases = [
+            DatabaseConnectionConfig("master", database_config, data_stores=["main"])
+        ]
 
         self.set_databasepath(config.get("database_path"))
 
@@ -76,11 +101,24 @@ class DatabaseConfig(Config):
         self.set_databasepath(args.database_path)
 
     def set_databasepath(self, database_path):
+        if database_path is None:
+            return
+
         if database_path != ":memory:":
             database_path = self.abspath(database_path)
-        if self.database_config.get("name", None) == "sqlite3":
-            if database_path is not None:
-                self.database_config["args"]["database"] = database_path
+
+        # We only support setting a database path if we have a single sqlite3
+        # database.
+        if len(self.databases) != 1:
+            raise ConfigError("Cannot specify 'database_path' with multiple databases")
+
+        database = self.get_single_database()
+        if database.config["name"] != "sqlite3":
+            # We don't raise here as we haven't done so before for this case.
+            logger.warn("Ignoring 'database_path' for non-sqlite3 database")
+            return
+
+        database.config["args"]["database"] = database_path
 
     @staticmethod
     def add_arguments(parser):
@@ -91,3 +129,11 @@ class DatabaseConfig(Config):
             metavar="SQLITE_DATABASE_PATH",
             help="The path to a sqlite database to use.",
         )
+
+    def get_single_database(self) -> DatabaseConnectionConfig:
+        """Returns the database if there is only one, useful for e.g. tests
+        """
+        if len(self.databases) != 1:
+            raise Exception("More than one database exists")
+
+        return self.databases[0]