diff --git a/synapse/config/database.py b/synapse/config/database.py
index 0e2509f0b1..5f2f3c7cfd 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -12,12 +12,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import os
from textwrap import indent
+from typing import List
import yaml
-from ._base import Config
+from synapse.config._base import Config, ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+class DatabaseConnectionConfig:
+ """Contains the connection config for a particular database.
+
+ Args:
+ name: A label for the database, used for logging.
+ db_config: The config for a particular database, as per `database`
+ section of main config. Has two fields: `name` for database
+ module name, and `args` for the args to give to the database
+ connector.
+ data_stores: The list of data stores that should be provisioned on the
+ database.
+ """
+
+ def __init__(self, name: str, db_config: dict, data_stores: List[str]):
+ if db_config["name"] not in ("sqlite3", "psycopg2"):
+ raise ConfigError("Unsupported database type %r" % (db_config["name"],))
+
+ if db_config["name"] == "sqlite3":
+ db_config.setdefault("args", {}).update(
+ {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
+ )
+
+ self.name = name
+ self.config = db_config
+ self.data_stores = data_stores
class DatabaseConfig(Config):
@@ -26,20 +57,14 @@ class DatabaseConfig(Config):
def read_config(self, config, **kwargs):
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
- self.database_config = config.get("database")
+ database_config = config.get("database")
- if self.database_config is None:
- self.database_config = {"name": "sqlite3", "args": {}}
+ if database_config is None:
+ database_config = {"name": "sqlite3", "args": {}}
- name = self.database_config.get("name", None)
- if name == "psycopg2":
- pass
- elif name == "sqlite3":
- self.database_config.setdefault("args", {}).update(
- {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
- )
- else:
- raise RuntimeError("Unsupported database type '%s'" % (name,))
+ self.databases = [
+ DatabaseConnectionConfig("master", database_config, data_stores=["main"])
+ ]
self.set_databasepath(config.get("database_path"))
@@ -76,11 +101,24 @@ class DatabaseConfig(Config):
self.set_databasepath(args.database_path)
def set_databasepath(self, database_path):
+ if database_path is None:
+ return
+
if database_path != ":memory:":
database_path = self.abspath(database_path)
- if self.database_config.get("name", None) == "sqlite3":
- if database_path is not None:
- self.database_config["args"]["database"] = database_path
+
+ # We only support setting a database path if we have a single sqlite3
+ # database.
+ if len(self.databases) != 1:
+ raise ConfigError("Cannot specify 'database_path' with multiple databases")
+
+ database = self.get_single_database()
+ if database.config["name"] != "sqlite3":
+ # We don't raise here as we haven't done so before for this case.
+ logger.warn("Ignoring 'database_path' for non-sqlite3 database")
+ return
+
+ database.config["args"]["database"] = database_path
@staticmethod
def add_arguments(parser):
@@ -91,3 +129,11 @@ class DatabaseConfig(Config):
metavar="SQLITE_DATABASE_PATH",
help="The path to a sqlite database to use.",
)
+
+ def get_single_database(self) -> DatabaseConnectionConfig:
+ """Returns the database if there is only one, useful for e.g. tests
+ """
+ if len(self.databases) != 1:
+ raise Exception("More than one database exists")
+
+ return self.databases[0]
|