diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ec19ae1d9d..1003dd84a5 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -24,9 +24,11 @@ from six.moves import intern, range
from prometheus_client import Histogram
+from twisted.enterprise import adbapi
from twisted.internet import defer
from synapse.api.errors import StoreError
+from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
@@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
}
+def make_pool(
+ reactor, db_config: DatabaseConnectionConfig, engine
+) -> adbapi.ConnectionPool:
+ """Get the connection pool for the database.
+ """
+
+ return adbapi.ConnectionPool(
+ db_config.config["name"],
+ cp_reactor=reactor,
+ cp_openfun=engine.on_new_connection,
+ **db_config.config.get("args", {})
+ )
+
+
+def make_conn(db_config: DatabaseConnectionConfig, engine):
+ """Make a new connection to the database and return it.
+
+ Returns:
+ Connection
+ """
+
+ db_params = {
+ k: v
+ for k, v in db_config.config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = engine.module.connect(**db_params)
+ engine.on_new_connection(db_conn)
+ return db_conn
+
+
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
@@ -218,10 +251,11 @@ class Database(object):
_TXN_ID = 0
- def __init__(self, hs):
+ def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
self.hs = hs
self._clock = hs.get_clock()
- self._db_pool = hs.get_db_pool()
+ self._database_config = database_config
+ self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
self.updates = BackgroundUpdater(hs, self)
@@ -234,7 +268,7 @@ class Database(object):
# to watch it
self._txn_perf_counters = PerformanceCounters()
- self.engine = hs.database_engine
+ self.engine = engine
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
@@ -255,6 +289,11 @@ class Database(object):
self._check_safe_to_upsert,
)
+ def is_running(self):
+ """Is the database pool currently running
+ """
+ return self._db_pool.running
+
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
"""
|