diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 92b8726093..596ddc6970 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -64,28 +64,29 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
+ datastores = Mock()
+ datastores.main = Mock(
+ spec=[
+ # Bits that Federation needs
+ "prep_send_transaction",
+ "delivered_txn",
+ "get_received_txn_response",
+ "set_received_txn_response",
+ "get_destination_retry_timings",
+ "get_devices_by_remote",
+ # Bits that user_directory needs
+ "get_user_directory_stream_pos",
+ "get_current_state_deltas",
+ "get_device_updates_by_remote",
+ ]
+ )
+
hs = self.setup_test_homeserver(
- datastore=(
- Mock(
- spec=[
- # Bits that Federation needs
- "prep_send_transaction",
- "delivered_txn",
- "get_received_txn_response",
- "set_received_txn_response",
- "get_destination_retry_timings",
- "get_device_updates_by_remote",
- # Bits that user_directory needs
- "get_user_directory_stream_pos",
- "get_current_state_deltas",
- ]
- )
- ),
- notifier=Mock(),
- http_client=mock_federation_client,
- keyring=mock_keyring,
+ notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
)
+ hs.datastores = datastores
+
return hs
def prepare(self, reactor, clock, hs):
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 3dae83c543..2a1e7c7166 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -20,7 +20,7 @@ from synapse.replication.tcp.client import (
ReplicationClientHandler,
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.storage.database import Database
+from synapse.storage.database import make_conn
from tests import unittest
from tests.server import FakeTransport
@@ -41,10 +41,12 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
+ db_config = hs.config.database.get_single_database()
self.master_store = self.hs.get_datastore()
self.storage = hs.get_storage()
+ database = hs.get_datastores().databases[0]
self.slaved_store = self.STORE_TYPE(
- Database(hs), self.hs.get_db_conn(), self.hs
+ database, make_conn(db_config, database.engine), self.hs
)
self.event_id = 0
diff --git a/tests/server.py b/tests/server.py
index 2b7cf4242e..a554dfdd57 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -302,41 +302,42 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
Set up a synchronous test server, driven by the reactor used by
the homeserver.
"""
- d = _sth(cleanup_func, *args, **kwargs).result
+ server = _sth(cleanup_func, *args, **kwargs)
- if isinstance(d, Failure):
- d.raiseException()
+ database = server.config.database.get_single_database()
# Make the thread pool synchronous.
- clock = d.get_clock()
- pool = d.get_db_pool()
-
- def runWithConnection(func, *args, **kwargs):
- return threads.deferToThreadPool(
- pool._reactor,
- pool.threadpool,
- pool._runWithConnection,
- func,
- *args,
- **kwargs
- )
-
- def runInteraction(interaction, *args, **kwargs):
- return threads.deferToThreadPool(
- pool._reactor,
- pool.threadpool,
- pool._runInteraction,
- interaction,
- *args,
- **kwargs
- )
+ clock = server.get_clock()
+
+ for database in server.get_datastores().databases:
+ pool = database._db_pool
+
+ def runWithConnection(func, *args, **kwargs):
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runWithConnection,
+ func,
+ *args,
+ **kwargs
+ )
+
+ def runInteraction(interaction, *args, **kwargs):
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runInteraction,
+ interaction,
+ *args,
+ **kwargs
+ )
- if pool:
pool.runWithConnection = runWithConnection
pool.runInteraction = runInteraction
pool.threadpool = ThreadPool(clock._reactor)
pool.running = True
- return d
+
+ return server
def get_clock():
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 2e521e9ab7..fd52512696 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -28,7 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_conn
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -55,8 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
- database = Database(hs)
- self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ self.store = ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
def tearDown(self):
# TODO: suboptimal that we need to create files for tests!
@@ -111,9 +113,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
hs.config.event_cache_size = 1
hs.config.password_providers = []
- self.db_pool = hs.get_db_pool()
- self.engine = hs.database_engine
-
self.as_list = [
{"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
{"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},
@@ -125,8 +124,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files = []
- database = Database(hs)
- self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
+ # We assume there is only one database in these tests
+ database = hs.get_datastores().databases[0]
+ self.db_pool = database._db_pool
+ self.engine = database.engine
+
+ db_config = hs.config.get_single_database()
+ self.store = TestTransactionStore(
+ database, make_conn(db_config, self.engine), hs
+ )
def _add_service(self, url, as_token, id):
as_yaml = dict(
@@ -419,7 +425,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs.config.event_cache_size = 1
hs.config.password_providers = []
- ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
@defer.inlineCallbacks
def test_duplicate_ids(self):
@@ -435,7 +444,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
e = cm.exception
self.assertIn(f1, str(e))
@@ -456,7 +468,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
e = cm.exception
self.assertIn(f1, str(e))
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 537cfe9f64..cdee0a9e60 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -52,15 +52,17 @@ class SQLBaseStoreTestCase(unittest.TestCase):
config = Mock()
config._disable_native_upserts = True
config.event_cache_size = 1
- config.database_config = {"name": "sqlite3"}
- engine = create_engine(config.database_config)
+ hs = TestHomeServer("test", config=config)
+
+ sqlite_config = {"name": "sqlite3"}
+ engine = create_engine(sqlite_config)
fake_engine = Mock(wraps=engine)
fake_engine.can_native_upsert = False
- hs = TestHomeServer(
- "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
- )
- self.datastore = SQLBaseStore(Database(hs), None, hs)
+ db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+ db._db_pool = self.db_pool
+
+ self.datastore = SQLBaseStore(db, None, hs)
@defer.inlineCallbacks
def test_insert_1col(self):
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4578cc3b60..ed5786865a 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
- self.db_pool = hs.get_db_pool()
self.store = hs.get_datastore()
diff --git a/tests/utils.py b/tests/utils.py
index 585f305b9a..9f5bf40b4b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -30,6 +30,7 @@ from twisted.internet import defer, reactor
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions
+from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.federation.transport import server as federation_server
@@ -177,7 +178,6 @@ class TestHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
-@defer.inlineCallbacks
def setup_test_homeserver(
cleanup_func,
name="test",
@@ -214,7 +214,7 @@ def setup_test_homeserver(
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
- config.database_config = {
+ database_config = {
"name": "psycopg2",
"args": {
"database": test_db,
@@ -226,12 +226,15 @@ def setup_test_homeserver(
},
}
else:
- config.database_config = {
+ database_config = {
"name": "sqlite3",
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
}
- db_engine = create_engine(config.database_config)
+ database = DatabaseConnectionConfig("master", database_config, ["main"])
+ config.database.databases = [database]
+
+ db_engine = create_engine(database.config)
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
@@ -251,11 +254,6 @@ def setup_test_homeserver(
cur.close()
db_conn.close()
- # we need to configure the connection pool to run the on_new_connection
- # function, so that we can test code that uses custom sqlite functions
- # (like rank).
- config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
-
if datastore is None:
hs = homeserverToUse(
name,
@@ -267,21 +265,19 @@ def setup_test_homeserver(
**kargs
)
- # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
- # date db
- if not isinstance(db_engine, PostgresEngine):
- db_conn = hs.get_db_conn()
- yield prepare_database(db_conn, db_engine, config)
- db_conn.commit()
- db_conn.close()
+ hs.setup()
+ if homeserverToUse.__name__ == "TestHomeServer":
+ hs.setup_master()
+
+ if isinstance(db_engine, PostgresEngine):
+ database = hs.get_datastores().databases[0]
- else:
# We need to do cleanup on PostgreSQL
def cleanup():
import psycopg2
# Close all the db pools
- hs.get_db_pool().close()
+ database._db_pool.close()
dropped = False
@@ -320,23 +316,12 @@ def setup_test_homeserver(
# Register the cleanup hook
cleanup_func(cleanup)
- hs.setup()
- if homeserverToUse.__name__ == "TestHomeServer":
- hs.setup_master()
else:
- # If we have been given an explicit datastore we probably want to mock
- # out the DataStores somehow too. This all feels a bit wrong, but then
- # mocking the stores feels wrong too.
- datastores = Mock(datastore=datastore)
-
hs = homeserverToUse(
name,
- db_pool=None,
datastore=datastore,
- datastores=datastores,
config=config,
version_string="Synapse/tests",
- database_engine=db_engine,
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,
|