summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6513.misc1
-rwxr-xr-xscripts-dev/update_database9
-rwxr-xr-xscripts/synapse_port_db58
-rw-r--r--synapse/config/database.py78
-rw-r--r--synapse/handlers/presence.py2
-rw-r--r--synapse/server.py41
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/data_stores/__init__.py40
-rw-r--r--synapse/storage/data_stores/main/client_ips.py2
-rw-r--r--synapse/storage/database.py45
-rw-r--r--synapse/storage/engines/sqlite.py16
-rw-r--r--synapse/storage/prepare_database.py7
-rw-r--r--tests/handlers/test_typing.py39
-rw-r--r--tests/replication/slave/storage/_base.py6
-rw-r--r--tests/server.py55
-rw-r--r--tests/storage/test_appservice.py37
-rw-r--r--tests/storage/test_base.py14
-rw-r--r--tests/storage/test_registration.py1
-rw-r--r--tests/utils.py43
19 files changed, 287 insertions, 209 deletions
diff --git a/changelog.d/6513.misc b/changelog.d/6513.misc
new file mode 100644
index 0000000000..36700f5657
--- /dev/null
+++ b/changelog.d/6513.misc
@@ -0,0 +1 @@
+Remove all assumptions of there being a single phyiscal DB apart from the `synapse.config`.
diff --git a/scripts-dev/update_database b/scripts-dev/update_database
index 23017c21f8..1d62f0403a 100755
--- a/scripts-dev/update_database
+++ b/scripts-dev/update_database
@@ -26,7 +26,6 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.server import HomeServer
 from synapse.storage import DataStore
-from synapse.storage.prepare_database import prepare_database
 
 logger = logging.getLogger("update_database")
 
@@ -77,12 +76,8 @@ if __name__ == "__main__":
     # Instantiate and initialise the homeserver object.
     hs = MockHomeserver(config)
 
-    db_conn = hs.get_db_conn()
-    # Update the database to the latest schema.
-    prepare_database(db_conn, hs.database_engine, config=config)
-    db_conn.commit()
-
-    # setup instantiates the store within the homeserver object.
+    # Setup instantiates the store within the homeserver object and updates the
+    # DB.
     hs.setup()
     store = hs.get_datastore()
 
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index e393a9b2f7..5b5368988c 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -30,6 +30,7 @@ import yaml
 from twisted.enterprise import adbapi
 from twisted.internet import defer, reactor
 
+from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.logging.context import PreserveLoggingContext
 from synapse.storage._base import LoggingTransaction
@@ -55,7 +56,7 @@ from synapse.storage.data_stores.main.stats import StatsStore
 from synapse.storage.data_stores.main.user_directory import (
     UserDirectoryBackgroundUpdateStore,
 )
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_conn
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 from synapse.util import Clock
@@ -165,23 +166,17 @@ class Store(
 
 
 class MockHomeserver:
-    def __init__(self, config, database_engine, db_conn, db_pool):
-        self.database_engine = database_engine
-        self.db_conn = db_conn
-        self.db_pool = db_pool
+    def __init__(self, config):
         self.clock = Clock(reactor)
         self.config = config
         self.hostname = config.server_name
 
-    def get_db_conn(self):
-        return self.db_conn
-
-    def get_db_pool(self):
-        return self.db_pool
-
     def get_clock(self):
         return self.clock
 
+    def get_reactor(self):
+        return reactor
+
 
 class Porter(object):
     def __init__(self, **kwargs):
@@ -445,45 +440,36 @@ class Porter(object):
             else:
                 return
 
-    def setup_db(self, db_config, database_engine):
-        db_conn = database_engine.module.connect(
-            **{
-                k: v
-                for k, v in db_config.get("args", {}).items()
-                if not k.startswith("cp_")
-            }
-        )
-
-        prepare_database(db_conn, database_engine, config=None)
+    def setup_db(self, db_config: DatabaseConnectionConfig, engine):
+        db_conn = make_conn(db_config, engine)
+        prepare_database(db_conn, engine, config=None)
 
         db_conn.commit()
 
         return db_conn
 
     @defer.inlineCallbacks
-    def build_db_store(self, config):
+    def build_db_store(self, db_config: DatabaseConnectionConfig):
         """Builds and returns a database store using the provided configuration.
 
         Args:
-            config: The database configuration, i.e. a dict following the structure of
-                the "database" section of Synapse's configuration file.
+            config: The database configuration
 
         Returns:
             The built Store object.
         """
-        engine = create_engine(config)
-
-        self.progress.set_state("Preparing %s" % config["name"])
-        conn = self.setup_db(config, engine)
+        self.progress.set_state("Preparing %s" % db_config.config["name"])
 
-        db_pool = adbapi.ConnectionPool(config["name"], **config["args"])
+        engine = create_engine(db_config.config)
+        conn = self.setup_db(db_config, engine)
 
-        hs = MockHomeserver(self.hs_config, engine, conn, db_pool)
+        hs = MockHomeserver(self.hs_config)
 
-        store = Store(Database(hs), conn, hs)
+        store = Store(Database(hs, db_config, engine), conn, hs)
 
         yield store.db.runInteraction(
-            "%s_engine.check_database" % config["name"], engine.check_database,
+            "%s_engine.check_database" % db_config.config["name"],
+            engine.check_database,
         )
 
         return store
@@ -509,7 +495,11 @@ class Porter(object):
     @defer.inlineCallbacks
     def run(self):
         try:
-            self.sqlite_store = yield self.build_db_store(self.sqlite_config)
+            self.sqlite_store = yield self.build_db_store(
+                DatabaseConnectionConfig(
+                    "master", self.sqlite_config, data_stores=["main"]
+                )
+            )
 
             # Check if all background updates are done, abort if not.
             updates_complete = (
@@ -524,7 +514,7 @@ class Porter(object):
                 defer.returnValue(None)
 
             self.postgres_store = yield self.build_db_store(
-                self.hs_config.database_config
+                self.hs_config.get_single_database()
             )
 
             yield self.run_background_updates_on_postgres()
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 0e2509f0b1..5f2f3c7cfd 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -12,12 +12,43 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 from textwrap import indent
+from typing import List
 
 import yaml
 
-from ._base import Config
+from synapse.config._base import Config, ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+class DatabaseConnectionConfig:
+    """Contains the connection config for a particular database.
+
+    Args:
+        name: A label for the database, used for logging.
+        db_config: The config for a particular database, as per `database`
+            section of main config. Has two fields: `name` for database
+            module name, and `args` for the args to give to the database
+            connector.
+        data_stores: The list of data stores that should be provisioned on the
+            database.
+    """
+
+    def __init__(self, name: str, db_config: dict, data_stores: List[str]):
+        if db_config["name"] not in ("sqlite3", "psycopg2"):
+            raise ConfigError("Unsupported database type %r" % (db_config["name"],))
+
+        if db_config["name"] == "sqlite3":
+            db_config.setdefault("args", {}).update(
+                {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
+            )
+
+        self.name = name
+        self.config = db_config
+        self.data_stores = data_stores
 
 
 class DatabaseConfig(Config):
@@ -26,20 +57,14 @@ class DatabaseConfig(Config):
     def read_config(self, config, **kwargs):
         self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
 
-        self.database_config = config.get("database")
+        database_config = config.get("database")
 
-        if self.database_config is None:
-            self.database_config = {"name": "sqlite3", "args": {}}
+        if database_config is None:
+            database_config = {"name": "sqlite3", "args": {}}
 
-        name = self.database_config.get("name", None)
-        if name == "psycopg2":
-            pass
-        elif name == "sqlite3":
-            self.database_config.setdefault("args", {}).update(
-                {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
-            )
-        else:
-            raise RuntimeError("Unsupported database type '%s'" % (name,))
+        self.databases = [
+            DatabaseConnectionConfig("master", database_config, data_stores=["main"])
+        ]
 
         self.set_databasepath(config.get("database_path"))
 
@@ -76,11 +101,24 @@ class DatabaseConfig(Config):
         self.set_databasepath(args.database_path)
 
     def set_databasepath(self, database_path):
+        if database_path is None:
+            return
+
         if database_path != ":memory:":
             database_path = self.abspath(database_path)
-        if self.database_config.get("name", None) == "sqlite3":
-            if database_path is not None:
-                self.database_config["args"]["database"] = database_path
+
+        # We only support setting a database path if we have a single sqlite3
+        # database.
+        if len(self.databases) != 1:
+            raise ConfigError("Cannot specify 'database_path' with multiple databases")
+
+        database = self.get_single_database()
+        if database.config["name"] != "sqlite3":
+            # We don't raise here as we haven't done so before for this case.
+            logger.warn("Ignoring 'database_path' for non-sqlite3 database")
+            return
+
+        database.config["args"]["database"] = database_path
 
     @staticmethod
     def add_arguments(parser):
@@ -91,3 +129,11 @@ class DatabaseConfig(Config):
             metavar="SQLITE_DATABASE_PATH",
             help="The path to a sqlite database to use.",
         )
+
+    def get_single_database(self) -> DatabaseConnectionConfig:
+        """Returns the database if there is only one, useful for e.g. tests
+        """
+        if len(self.databases) != 1:
+            raise Exception("More than one database exists")
+
+        return self.databases[0]
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index eda15bc623..240c4add12 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -230,7 +230,7 @@ class PresenceHandler(object):
         is some spurious presence changes that will self-correct.
         """
         # If the DB pool has already terminated, don't try updating
-        if not self.hs.get_db_pool().running:
+        if not self.store.database.is_running():
             return
 
         logger.info(
diff --git a/synapse/server.py b/synapse/server.py
index 5021068ce0..7926867b77 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -25,7 +25,6 @@ import abc
 import logging
 import os
 
-from twisted.enterprise import adbapi
 from twisted.mail.smtp import sendmail
 from twisted.web.client import BrowserLikePolicyForHTTPS
 
@@ -98,7 +97,6 @@ from synapse.server_notices.worker_server_notices_sender import (
 )
 from synapse.state import StateHandler, StateResolutionHandler
 from synapse.storage import DataStores, Storage
-from synapse.storage.engines import create_engine
 from synapse.streams.events import EventSources
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
@@ -134,7 +132,6 @@ class HomeServer(object):
 
     DEPENDENCIES = [
         "http_client",
-        "db_pool",
         "federation_client",
         "federation_server",
         "handlers",
@@ -233,12 +230,6 @@ class HomeServer(object):
         self.admin_redaction_ratelimiter = Ratelimiter()
         self.registration_ratelimiter = Ratelimiter()
 
-        self.database_engine = create_engine(config.database_config)
-        config.database_config.setdefault("args", {})[
-            "cp_openfun"
-        ] = self.database_engine.on_new_connection
-        self.db_config = config.database_config
-
         self.datastores = None
 
         # Other kwargs are explicit dependencies
@@ -247,10 +238,8 @@ class HomeServer(object):
 
     def setup(self):
         logger.info("Setting up.")
-        with self.get_db_conn() as conn:
-            self.datastores = DataStores(self.DATASTORE_CLASS, conn, self)
-            conn.commit()
         self.start_time = int(self.get_clock().time())
+        self.datastores = DataStores(self.DATASTORE_CLASS, self)
         logger.info("Finished setting up.")
 
     def setup_master(self):
@@ -284,6 +273,9 @@ class HomeServer(object):
     def get_datastore(self):
         return self.datastores.main
 
+    def get_datastores(self):
+        return self.datastores
+
     def get_config(self):
         return self.config
 
@@ -433,31 +425,6 @@ class HomeServer(object):
         )
         return MatrixFederationHttpClient(self, tls_client_options_factory)
 
-    def build_db_pool(self):
-        name = self.db_config["name"]
-
-        return adbapi.ConnectionPool(
-            name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {})
-        )
-
-    def get_db_conn(self, run_new_connection=True):
-        """Makes a new connection to the database, skipping the db pool
-
-        Returns:
-            Connection: a connection object implementing the PEP-249 spec
-        """
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v
-            for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def build_media_repository_resource(self):
         # build the media repo resource. This indirects through the HomeServer
         # to ensure that we only have a single instance of
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b7637b5dc0..88546ad614 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -40,7 +40,7 @@ class SQLBaseStore(object):
     def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
-        self.database_engine = hs.database_engine
+        self.database_engine = database.engine
         self.db = database
         self.rand = random.SystemRandom()
 
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index cafedd5c0d..0983e059c0 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -13,24 +13,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.database import Database
+import logging
+
+from synapse.storage.database import Database, make_conn
+from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 
+logger = logging.getLogger(__name__)
+
 
 class DataStores(object):
     """The various data stores.
 
     These are low level interfaces to physical databases.
+
+    Attributes:
+        main (DataStore)
     """
 
-    def __init__(self, main_store_class, db_conn, hs):
+    def __init__(self, main_store_class, hs):
         # Note we pass in the main store class here as workers use a different main
         # store.
-        database = Database(hs)
 
-        # Check that db is correctly configured.
-        database.engine.check_database(db_conn.cursor())
+        self.databases = []
+
+        for database_config in hs.config.database.databases:
+            db_name = database_config.name
+            engine = create_engine(database_config.config)
+
+            with make_conn(database_config, engine) as db_conn:
+                logger.info("Preparing database %r...", db_name)
+
+                engine.check_database(db_conn.cursor())
+                prepare_database(
+                    db_conn, engine, hs.config, data_stores=database_config.data_stores,
+                )
+
+                database = Database(hs, database_config, engine)
+
+                if "main" in database_config.data_stores:
+                    logger.info("Starting 'main' data store")
+                    self.main = main_store_class(database, db_conn, hs)
+
+                db_conn.commit()
 
-        prepare_database(db_conn, database.engine, config=hs.config)
+                self.databases.append(database)
 
-        self.main = main_store_class(database, db_conn, hs)
+                logger.info("Database %r prepared", db_name)
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index add3037b69..13f4c9c72e 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -412,7 +412,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
     def _update_client_ips_batch(self):
 
         # If the DB pool has already terminated, don't try updating
-        if not self.hs.get_db_pool().running:
+        if not self.db.is_running():
             return
 
         to_update = self._batch_row_update
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ec19ae1d9d..1003dd84a5 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -24,9 +24,11 @@ from six.moves import intern, range
 
 from prometheus_client import Histogram
 
+from twisted.enterprise import adbapi
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
+from synapse.config.database import DatabaseConnectionConfig
 from synapse.logging.context import LoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
@@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
 }
 
 
+def make_pool(
+    reactor, db_config: DatabaseConnectionConfig, engine
+) -> adbapi.ConnectionPool:
+    """Get the connection pool for the database.
+    """
+
+    return adbapi.ConnectionPool(
+        db_config.config["name"],
+        cp_reactor=reactor,
+        cp_openfun=engine.on_new_connection,
+        **db_config.config.get("args", {})
+    )
+
+
+def make_conn(db_config: DatabaseConnectionConfig, engine):
+    """Make a new connection to the database and return it.
+
+    Returns:
+        Connection
+    """
+
+    db_params = {
+        k: v
+        for k, v in db_config.config.get("args", {}).items()
+        if not k.startswith("cp_")
+    }
+    db_conn = engine.module.connect(**db_params)
+    engine.on_new_connection(db_conn)
+    return db_conn
+
+
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
@@ -218,10 +251,11 @@ class Database(object):
 
     _TXN_ID = 0
 
-    def __init__(self, hs):
+    def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
         self.hs = hs
         self._clock = hs.get_clock()
-        self._db_pool = hs.get_db_pool()
+        self._database_config = database_config
+        self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
 
         self.updates = BackgroundUpdater(hs, self)
 
@@ -234,7 +268,7 @@ class Database(object):
         #   to watch it
         self._txn_perf_counters = PerformanceCounters()
 
-        self.engine = hs.database_engine
+        self.engine = engine
 
         # A set of tables that are not safe to use native upserts in.
         self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
@@ -255,6 +289,11 @@ class Database(object):
                 self._check_safe_to_upsert,
             )
 
+    def is_running(self):
+        """Is the database pool currently running
+        """
+        return self._db_pool.running
+
     @defer.inlineCallbacks
     def _check_safe_to_upsert(self):
         """
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index ddad17dc5a..df039a072d 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -16,8 +16,6 @@
 import struct
 import threading
 
-from synapse.storage.prepare_database import prepare_database
-
 
 class Sqlite3Engine(object):
     single_threaded = True
@@ -25,6 +23,9 @@ class Sqlite3Engine(object):
     def __init__(self, database_module, database_config):
         self.module = database_module
 
+        database = database_config.get("args", {}).get("database")
+        self._is_in_memory = database in (None, ":memory:",)
+
         # The current max state_group, or None if we haven't looked
         # in the DB yet.
         self._current_state_group_id = None
@@ -59,7 +60,16 @@ class Sqlite3Engine(object):
         return sql
 
     def on_new_connection(self, db_conn):
-        prepare_database(db_conn, self, config=None)
+
+        # We need to import here to avoid an import loop.
+        from synapse.storage.prepare_database import prepare_database
+
+        if self._is_in_memory:
+            # In memory databases need to be rebuilt each time. Ideally we'd
+            # reuse the same connection as we do when starting up, but that
+            # would involve using adbapi before we have started the reactor.
+            prepare_database(db_conn, self, config=None)
+
         db_conn.create_function("rank", 1, _rank)
 
     def is_deadlock(self, error):
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 731e1c9d9c..b4194b44ee 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -41,7 +41,7 @@ class UpgradeDatabaseException(PrepareDatabaseException):
     pass
 
 
-def prepare_database(db_conn, database_engine, config):
+def prepare_database(db_conn, database_engine, config, data_stores=["main"]):
     """Prepares a database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
 
@@ -54,11 +54,10 @@ def prepare_database(db_conn, database_engine, config):
         config (synapse.config.homeserver.HomeServerConfig|None):
             application config, or None if we are connecting to an existing
             database which we expect to be configured already
+        data_stores (list[str]): The name of the data stores that will be used
+            with this database. Defaults to all data stores.
     """
 
-    # For now we only have the one datastore.
-    data_stores = ["main"]
-
     try:
         cur = db_conn.cursor()
         version_info = _get_or_create_schema_state(cur, database_engine)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 92b8726093..596ddc6970 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -64,28 +64,29 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         mock_federation_client = Mock(spec=["put_json"])
         mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
 
+        datastores = Mock()
+        datastores.main = Mock(
+            spec=[
+                # Bits that Federation needs
+                "prep_send_transaction",
+                "delivered_txn",
+                "get_received_txn_response",
+                "set_received_txn_response",
+                "get_destination_retry_timings",
+                "get_devices_by_remote",
+                # Bits that user_directory needs
+                "get_user_directory_stream_pos",
+                "get_current_state_deltas",
+                "get_device_updates_by_remote",
+            ]
+        )
+
         hs = self.setup_test_homeserver(
-            datastore=(
-                Mock(
-                    spec=[
-                        # Bits that Federation needs
-                        "prep_send_transaction",
-                        "delivered_txn",
-                        "get_received_txn_response",
-                        "set_received_txn_response",
-                        "get_destination_retry_timings",
-                        "get_device_updates_by_remote",
-                        # Bits that user_directory needs
-                        "get_user_directory_stream_pos",
-                        "get_current_state_deltas",
-                    ]
-                )
-            ),
-            notifier=Mock(),
-            http_client=mock_federation_client,
-            keyring=mock_keyring,
+            notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
         )
 
+        hs.datastores = datastores
+
         return hs
 
     def prepare(self, reactor, clock, hs):
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 3dae83c543..2a1e7c7166 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -20,7 +20,7 @@ from synapse.replication.tcp.client import (
     ReplicationClientHandler,
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.storage.database import Database
+from synapse.storage.database import make_conn
 
 from tests import unittest
 from tests.server import FakeTransport
@@ -41,10 +41,12 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
 
+        db_config = hs.config.database.get_single_database()
         self.master_store = self.hs.get_datastore()
         self.storage = hs.get_storage()
+        database = hs.get_datastores().databases[0]
         self.slaved_store = self.STORE_TYPE(
-            Database(hs), self.hs.get_db_conn(), self.hs
+            database, make_conn(db_config, database.engine), self.hs
         )
         self.event_id = 0
 
diff --git a/tests/server.py b/tests/server.py
index 2b7cf4242e..a554dfdd57 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -302,41 +302,42 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
     Set up a synchronous test server, driven by the reactor used by
     the homeserver.
     """
-    d = _sth(cleanup_func, *args, **kwargs).result
+    server = _sth(cleanup_func, *args, **kwargs)
 
-    if isinstance(d, Failure):
-        d.raiseException()
+    database = server.config.database.get_single_database()
 
     # Make the thread pool synchronous.
-    clock = d.get_clock()
-    pool = d.get_db_pool()
-
-    def runWithConnection(func, *args, **kwargs):
-        return threads.deferToThreadPool(
-            pool._reactor,
-            pool.threadpool,
-            pool._runWithConnection,
-            func,
-            *args,
-            **kwargs
-        )
-
-    def runInteraction(interaction, *args, **kwargs):
-        return threads.deferToThreadPool(
-            pool._reactor,
-            pool.threadpool,
-            pool._runInteraction,
-            interaction,
-            *args,
-            **kwargs
-        )
+    clock = server.get_clock()
+
+    for database in server.get_datastores().databases:
+        pool = database._db_pool
+
+        def runWithConnection(func, *args, **kwargs):
+            return threads.deferToThreadPool(
+                pool._reactor,
+                pool.threadpool,
+                pool._runWithConnection,
+                func,
+                *args,
+                **kwargs
+            )
+
+        def runInteraction(interaction, *args, **kwargs):
+            return threads.deferToThreadPool(
+                pool._reactor,
+                pool.threadpool,
+                pool._runInteraction,
+                interaction,
+                *args,
+                **kwargs
+            )
 
-    if pool:
         pool.runWithConnection = runWithConnection
         pool.runInteraction = runInteraction
         pool.threadpool = ThreadPool(clock._reactor)
         pool.running = True
-    return d
+
+    return server
 
 
 def get_clock():
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 2e521e9ab7..fd52512696 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -28,7 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_conn
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
@@ -55,8 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        database = Database(hs)
-        self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
+        database = hs.get_datastores().databases[0]
+        self.store = ApplicationServiceStore(
+            database, make_conn(database._database_config, database.engine), hs
+        )
 
     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
@@ -111,9 +113,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         hs.config.event_cache_size = 1
         hs.config.password_providers = []
 
-        self.db_pool = hs.get_db_pool()
-        self.engine = hs.database_engine
-
         self.as_list = [
             {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
             {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},
@@ -125,8 +124,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         self.as_yaml_files = []
 
-        database = Database(hs)
-        self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
+        # We assume there is only one database in these tests
+        database = hs.get_datastores().databases[0]
+        self.db_pool = database._db_pool
+        self.engine = database.engine
+
+        db_config = hs.config.get_single_database()
+        self.store = TestTransactionStore(
+            database, make_conn(db_config, self.engine), hs
+        )
 
     def _add_service(self, url, as_token, id):
         as_yaml = dict(
@@ -419,7 +425,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.event_cache_size = 1
         hs.config.password_providers = []
 
-        ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+        database = hs.get_datastores().databases[0]
+        ApplicationServiceStore(
+            database, make_conn(database._database_config, database.engine), hs
+        )
 
     @defer.inlineCallbacks
     def test_duplicate_ids(self):
@@ -435,7 +444,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+            database = hs.get_datastores().databases[0]
+            ApplicationServiceStore(
+                database, make_conn(database._database_config, database.engine), hs
+            )
 
         e = cm.exception
         self.assertIn(f1, str(e))
@@ -456,7 +468,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+            database = hs.get_datastores().databases[0]
+            ApplicationServiceStore(
+                database, make_conn(database._database_config, database.engine), hs
+            )
 
         e = cm.exception
         self.assertIn(f1, str(e))
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 537cfe9f64..cdee0a9e60 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -52,15 +52,17 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         config = Mock()
         config._disable_native_upserts = True
         config.event_cache_size = 1
-        config.database_config = {"name": "sqlite3"}
-        engine = create_engine(config.database_config)
+        hs = TestHomeServer("test", config=config)
+
+        sqlite_config = {"name": "sqlite3"}
+        engine = create_engine(sqlite_config)
         fake_engine = Mock(wraps=engine)
         fake_engine.can_native_upsert = False
-        hs = TestHomeServer(
-            "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
-        )
 
-        self.datastore = SQLBaseStore(Database(hs), None, hs)
+        db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+        db._db_pool = self.db_pool
+
+        self.datastore = SQLBaseStore(db, None, hs)
 
     @defer.inlineCallbacks
     def test_insert_1col(self):
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4578cc3b60..ed5786865a 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         hs = yield setup_test_homeserver(self.addCleanup)
-        self.db_pool = hs.get_db_pool()
 
         self.store = hs.get_datastore()
 
diff --git a/tests/utils.py b/tests/utils.py
index 585f305b9a..9f5bf40b4b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -30,6 +30,7 @@ from twisted.internet import defer, reactor
 from synapse.api.constants import EventTypes
 from synapse.api.errors import CodeMessageException, cs_error
 from synapse.api.room_versions import RoomVersions
+from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
 from synapse.federation.transport import server as federation_server
@@ -177,7 +178,6 @@ class TestHomeServer(HomeServer):
     DATASTORE_CLASS = DataStore
 
 
-@defer.inlineCallbacks
 def setup_test_homeserver(
     cleanup_func,
     name="test",
@@ -214,7 +214,7 @@ def setup_test_homeserver(
     if USE_POSTGRES_FOR_TESTS:
         test_db = "synapse_test_%s" % uuid.uuid4().hex
 
-        config.database_config = {
+        database_config = {
             "name": "psycopg2",
             "args": {
                 "database": test_db,
@@ -226,12 +226,15 @@ def setup_test_homeserver(
             },
         }
     else:
-        config.database_config = {
+        database_config = {
             "name": "sqlite3",
             "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
         }
 
-    db_engine = create_engine(config.database_config)
+    database = DatabaseConnectionConfig("master", database_config, ["main"])
+    config.database.databases = [database]
+
+    db_engine = create_engine(database.config)
 
     # Create the database before we actually try and connect to it, based off
     # the template database we generate in setupdb()
@@ -251,11 +254,6 @@ def setup_test_homeserver(
         cur.close()
         db_conn.close()
 
-    # we need to configure the connection pool to run the on_new_connection
-    # function, so that we can test code that uses custom sqlite functions
-    # (like rank).
-    config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
-
     if datastore is None:
         hs = homeserverToUse(
             name,
@@ -267,21 +265,19 @@ def setup_test_homeserver(
             **kargs
         )
 
-        # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
-        # date db
-        if not isinstance(db_engine, PostgresEngine):
-            db_conn = hs.get_db_conn()
-            yield prepare_database(db_conn, db_engine, config)
-            db_conn.commit()
-            db_conn.close()
+        hs.setup()
+        if homeserverToUse.__name__ == "TestHomeServer":
+            hs.setup_master()
+
+        if isinstance(db_engine, PostgresEngine):
+            database = hs.get_datastores().databases[0]
 
-        else:
             # We need to do cleanup on PostgreSQL
             def cleanup():
                 import psycopg2
 
                 # Close all the db pools
-                hs.get_db_pool().close()
+                database._db_pool.close()
 
                 dropped = False
 
@@ -320,23 +316,12 @@ def setup_test_homeserver(
                 # Register the cleanup hook
                 cleanup_func(cleanup)
 
-        hs.setup()
-        if homeserverToUse.__name__ == "TestHomeServer":
-            hs.setup_master()
     else:
-        # If we have been given an explicit datastore we probably want to mock
-        # out the DataStores somehow too. This all feels a bit wrong, but then
-        # mocking the stores feels wrong too.
-        datastores = Mock(datastore=datastore)
-
         hs = homeserverToUse(
             name,
-            db_pool=None,
             datastore=datastore,
-            datastores=datastores,
             config=config,
             version_string="Synapse/tests",
-            database_engine=db_engine,
             tls_server_context_factory=Mock(),
             tls_client_options_factory=Mock(),
             reactor=reactor,