summary refs log tree commit diff
path: root/synapse/storage/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r--synapse/storage/database.py45
1 files changed, 42 insertions, 3 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ec19ae1d9d..1003dd84a5 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -24,9 +24,11 @@ from six.moves import intern, range
 
 from prometheus_client import Histogram
 
+from twisted.enterprise import adbapi
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
+from synapse.config.database import DatabaseConnectionConfig
 from synapse.logging.context import LoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
@@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
 }
 
 
+def make_pool(
+    reactor, db_config: DatabaseConnectionConfig, engine
+) -> adbapi.ConnectionPool:
+    """Get the connection pool for the database.
+    """
+
+    return adbapi.ConnectionPool(
+        db_config.config["name"],
+        cp_reactor=reactor,
+        cp_openfun=engine.on_new_connection,
+        **db_config.config.get("args", {})
+    )
+
+
+def make_conn(db_config: DatabaseConnectionConfig, engine):
+    """Make a new connection to the database and return it.
+
+    Returns:
+        Connection
+    """
+
+    db_params = {
+        k: v
+        for k, v in db_config.config.get("args", {}).items()
+        if not k.startswith("cp_")
+    }
+    db_conn = engine.module.connect(**db_params)
+    engine.on_new_connection(db_conn)
+    return db_conn
+
+
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
@@ -218,10 +251,11 @@ class Database(object):
 
     _TXN_ID = 0
 
-    def __init__(self, hs):
+    def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
         self.hs = hs
         self._clock = hs.get_clock()
-        self._db_pool = hs.get_db_pool()
+        self._database_config = database_config
+        self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
 
         self.updates = BackgroundUpdater(hs, self)
 
@@ -234,7 +268,7 @@ class Database(object):
         #   to watch it
         self._txn_perf_counters = PerformanceCounters()
 
-        self.engine = hs.database_engine
+        self.engine = engine
 
         # A set of tables that are not safe to use native upserts in.
         self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
@@ -255,6 +289,11 @@ class Database(object):
                 self._check_safe_to_upsert,
             )
 
+    def is_running(self):
+        """Is the database pool currently running
+        """
+        return self._db_pool.running
+
     @defer.inlineCallbacks
     def _check_safe_to_upsert(self):
         """