summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/utils.py')
-rw-r--r--tests/utils.py43
1 files changed, 14 insertions, 29 deletions
diff --git a/tests/utils.py b/tests/utils.py
index 585f305b9a..9f5bf40b4b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -30,6 +30,7 @@ from twisted.internet import defer, reactor
 from synapse.api.constants import EventTypes
 from synapse.api.errors import CodeMessageException, cs_error
 from synapse.api.room_versions import RoomVersions
+from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
 from synapse.federation.transport import server as federation_server
@@ -177,7 +178,6 @@ class TestHomeServer(HomeServer):
     DATASTORE_CLASS = DataStore
 
 
-@defer.inlineCallbacks
 def setup_test_homeserver(
     cleanup_func,
     name="test",
@@ -214,7 +214,7 @@ def setup_test_homeserver(
     if USE_POSTGRES_FOR_TESTS:
         test_db = "synapse_test_%s" % uuid.uuid4().hex
 
-        config.database_config = {
+        database_config = {
             "name": "psycopg2",
             "args": {
                 "database": test_db,
@@ -226,12 +226,15 @@ def setup_test_homeserver(
             },
         }
     else:
-        config.database_config = {
+        database_config = {
             "name": "sqlite3",
             "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
         }
 
-    db_engine = create_engine(config.database_config)
+    database = DatabaseConnectionConfig("master", database_config, ["main"])
+    config.database.databases = [database]
+
+    db_engine = create_engine(database.config)
 
     # Create the database before we actually try and connect to it, based off
     # the template database we generate in setupdb()
@@ -251,11 +254,6 @@ def setup_test_homeserver(
         cur.close()
         db_conn.close()
 
-    # we need to configure the connection pool to run the on_new_connection
-    # function, so that we can test code that uses custom sqlite functions
-    # (like rank).
-    config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
-
     if datastore is None:
         hs = homeserverToUse(
             name,
@@ -267,21 +265,19 @@ def setup_test_homeserver(
             **kargs
         )
 
-        # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
-        # date db
-        if not isinstance(db_engine, PostgresEngine):
-            db_conn = hs.get_db_conn()
-            yield prepare_database(db_conn, db_engine, config)
-            db_conn.commit()
-            db_conn.close()
+        hs.setup()
+        if homeserverToUse.__name__ == "TestHomeServer":
+            hs.setup_master()
+
+        if isinstance(db_engine, PostgresEngine):
+            database = hs.get_datastores().databases[0]
 
-        else:
             # We need to do cleanup on PostgreSQL
             def cleanup():
                 import psycopg2
 
                 # Close all the db pools
-                hs.get_db_pool().close()
+                database._db_pool.close()
 
                 dropped = False
 
@@ -320,23 +316,12 @@ def setup_test_homeserver(
                 # Register the cleanup hook
                 cleanup_func(cleanup)
 
-        hs.setup()
-        if homeserverToUse.__name__ == "TestHomeServer":
-            hs.setup_master()
     else:
-        # If we have been given an explicit datastore we probably want to mock
-        # out the DataStores somehow too. This all feels a bit wrong, but then
-        # mocking the stores feels wrong too.
-        datastores = Mock(datastore=datastore)
-
         hs = homeserverToUse(
             name,
-            db_pool=None,
             datastore=datastore,
-            datastores=datastores,
             config=config,
             version_string="Synapse/tests",
-            database_engine=db_engine,
             tls_server_context_factory=Mock(),
             tls_client_options_factory=Mock(),
             reactor=reactor,