summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-12-18 10:45:12 +0000
committerGitHub <noreply@github.com>2019-12-18 10:45:12 +0000
commit2284eb3a533a2df04784df08da28e67d6588a5ea (patch)
treedf75df4b3eba90e8299c8bae61157d075a0d423b /synapse/storage
parentMerge branch 'master' into develop (diff)
downloadsynapse-2284eb3a533a2df04784df08da28e67d6588a5ea.tar.xz
Add database config class (#6513)
This encapsulates config for a given database and is the way to get new
connections.
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/data_stores/__init__.py40
-rw-r--r--synapse/storage/data_stores/main/client_ips.py2
-rw-r--r--synapse/storage/database.py45
-rw-r--r--synapse/storage/engines/sqlite.py16
-rw-r--r--synapse/storage/prepare_database.py7
6 files changed, 93 insertions, 19 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b7637b5dc0..88546ad614 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -40,7 +40,7 @@ class SQLBaseStore(object):
     def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
-        self.database_engine = hs.database_engine
+        self.database_engine = database.engine
         self.db = database
         self.rand = random.SystemRandom()
 
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index cafedd5c0d..0983e059c0 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -13,24 +13,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.database import Database
+import logging
+
+from synapse.storage.database import Database, make_conn
+from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 
+logger = logging.getLogger(__name__)
+
 
 class DataStores(object):
     """The various data stores.
 
     These are low level interfaces to physical databases.
+
+    Attributes:
+        main (DataStore)
     """
 
-    def __init__(self, main_store_class, db_conn, hs):
+    def __init__(self, main_store_class, hs):
         # Note we pass in the main store class here as workers use a different main
         # store.
-        database = Database(hs)
 
-        # Check that db is correctly configured.
-        database.engine.check_database(db_conn.cursor())
+        self.databases = []
+
+        for database_config in hs.config.database.databases:
+            db_name = database_config.name
+            engine = create_engine(database_config.config)
+
+            with make_conn(database_config, engine) as db_conn:
+                logger.info("Preparing database %r...", db_name)
+
+                engine.check_database(db_conn.cursor())
+                prepare_database(
+                    db_conn, engine, hs.config, data_stores=database_config.data_stores,
+                )
+
+                database = Database(hs, database_config, engine)
+
+                if "main" in database_config.data_stores:
+                    logger.info("Starting 'main' data store")
+                    self.main = main_store_class(database, db_conn, hs)
+
+                db_conn.commit()
 
-        prepare_database(db_conn, database.engine, config=hs.config)
+                self.databases.append(database)
 
-        self.main = main_store_class(database, db_conn, hs)
+                logger.info("Database %r prepared", db_name)
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index add3037b69..13f4c9c72e 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -412,7 +412,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
     def _update_client_ips_batch(self):
 
         # If the DB pool has already terminated, don't try updating
-        if not self.hs.get_db_pool().running:
+        if not self.db.is_running():
             return
 
         to_update = self._batch_row_update
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ec19ae1d9d..1003dd84a5 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -24,9 +24,11 @@ from six.moves import intern, range
 
 from prometheus_client import Histogram
 
+from twisted.enterprise import adbapi
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
+from synapse.config.database import DatabaseConnectionConfig
 from synapse.logging.context import LoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
@@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
 }
 
 
+def make_pool(
+    reactor, db_config: DatabaseConnectionConfig, engine
+) -> adbapi.ConnectionPool:
+    """Get the connection pool for the database.
+    """
+
+    return adbapi.ConnectionPool(
+        db_config.config["name"],
+        cp_reactor=reactor,
+        cp_openfun=engine.on_new_connection,
+        **db_config.config.get("args", {})
+    )
+
+
+def make_conn(db_config: DatabaseConnectionConfig, engine):
+    """Make a new connection to the database and return it.
+
+    Returns:
+        Connection
+    """
+
+    db_params = {
+        k: v
+        for k, v in db_config.config.get("args", {}).items()
+        if not k.startswith("cp_")
+    }
+    db_conn = engine.module.connect(**db_params)
+    engine.on_new_connection(db_conn)
+    return db_conn
+
+
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
@@ -218,10 +251,11 @@ class Database(object):
 
     _TXN_ID = 0
 
-    def __init__(self, hs):
+    def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
         self.hs = hs
         self._clock = hs.get_clock()
-        self._db_pool = hs.get_db_pool()
+        self._database_config = database_config
+        self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
 
         self.updates = BackgroundUpdater(hs, self)
 
@@ -234,7 +268,7 @@ class Database(object):
         #   to watch it
         self._txn_perf_counters = PerformanceCounters()
 
-        self.engine = hs.database_engine
+        self.engine = engine
 
         # A set of tables that are not safe to use native upserts in.
         self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
@@ -255,6 +289,11 @@ class Database(object):
                 self._check_safe_to_upsert,
             )
 
+    def is_running(self):
+        """Is the database pool currently running
+        """
+        return self._db_pool.running
+
     @defer.inlineCallbacks
     def _check_safe_to_upsert(self):
         """
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index ddad17dc5a..df039a072d 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -16,8 +16,6 @@
 import struct
 import threading
 
-from synapse.storage.prepare_database import prepare_database
-
 
 class Sqlite3Engine(object):
     single_threaded = True
@@ -25,6 +23,9 @@ class Sqlite3Engine(object):
     def __init__(self, database_module, database_config):
         self.module = database_module
 
+        database = database_config.get("args", {}).get("database")
+        self._is_in_memory = database in (None, ":memory:",)
+
         # The current max state_group, or None if we haven't looked
         # in the DB yet.
         self._current_state_group_id = None
@@ -59,7 +60,16 @@ class Sqlite3Engine(object):
         return sql
 
     def on_new_connection(self, db_conn):
-        prepare_database(db_conn, self, config=None)
+
+        # We need to import here to avoid an import loop.
+        from synapse.storage.prepare_database import prepare_database
+
+        if self._is_in_memory:
+            # In memory databases need to be rebuilt each time. Ideally we'd
+            # reuse the same connection as we do when starting up, but that
+            # would involve using adbapi before we have started the reactor.
+            prepare_database(db_conn, self, config=None)
+
         db_conn.create_function("rank", 1, _rank)
 
     def is_deadlock(self, error):
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 731e1c9d9c..b4194b44ee 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -41,7 +41,7 @@ class UpgradeDatabaseException(PrepareDatabaseException):
     pass
 
 
-def prepare_database(db_conn, database_engine, config):
+def prepare_database(db_conn, database_engine, config, data_stores=["main"]):
     """Prepares a database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
 
@@ -54,11 +54,10 @@ def prepare_database(db_conn, database_engine, config):
         config (synapse.config.homeserver.HomeServerConfig|None):
             application config, or None if we are connecting to an existing
             database which we expect to be configured already
+        data_stores (list[str]): The name of the data stores that will be used
+            with this database. Defaults to all data stores.
     """
 
-    # For now we only have the one datastore.
-    data_stores = ["main"]
-
     try:
         cur = db_conn.cursor()
         version_info = _get_or_create_schema_state(cur, database_engine)