summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-12-18 10:45:12 +0000
committerGitHub <noreply@github.com>2019-12-18 10:45:12 +0000
commit2284eb3a533a2df04784df08da28e67d6588a5ea (patch)
treedf75df4b3eba90e8299c8bae61157d075a0d423b /tests
parentMerge branch 'master' into develop (diff)
downloadsynapse-2284eb3a533a2df04784df08da28e67d6588a5ea.tar.xz
Add database config class (#6513)
This encapsulates config for a given database and is the way to get new
connections.
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_typing.py39
-rw-r--r--tests/replication/slave/storage/_base.py6
-rw-r--r--tests/server.py55
-rw-r--r--tests/storage/test_appservice.py37
-rw-r--r--tests/storage/test_base.py14
-rw-r--r--tests/storage/test_registration.py1
-rw-r--r--tests/utils.py43
7 files changed, 100 insertions, 95 deletions
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 92b8726093..596ddc6970 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -64,28 +64,29 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         mock_federation_client = Mock(spec=["put_json"])
         mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
 
+        datastores = Mock()
+        datastores.main = Mock(
+            spec=[
+                # Bits that Federation needs
+                "prep_send_transaction",
+                "delivered_txn",
+                "get_received_txn_response",
+                "set_received_txn_response",
+                "get_destination_retry_timings",
+                "get_devices_by_remote",
+                # Bits that user_directory needs
+                "get_user_directory_stream_pos",
+                "get_current_state_deltas",
+                "get_device_updates_by_remote",
+            ]
+        )
+
         hs = self.setup_test_homeserver(
-            datastore=(
-                Mock(
-                    spec=[
-                        # Bits that Federation needs
-                        "prep_send_transaction",
-                        "delivered_txn",
-                        "get_received_txn_response",
-                        "set_received_txn_response",
-                        "get_destination_retry_timings",
-                        "get_device_updates_by_remote",
-                        # Bits that user_directory needs
-                        "get_user_directory_stream_pos",
-                        "get_current_state_deltas",
-                    ]
-                )
-            ),
-            notifier=Mock(),
-            http_client=mock_federation_client,
-            keyring=mock_keyring,
+            notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
         )
 
+        hs.datastores = datastores
+
         return hs
 
     def prepare(self, reactor, clock, hs):
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 3dae83c543..2a1e7c7166 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -20,7 +20,7 @@ from synapse.replication.tcp.client import (
     ReplicationClientHandler,
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.storage.database import Database
+from synapse.storage.database import make_conn
 
 from tests import unittest
 from tests.server import FakeTransport
@@ -41,10 +41,12 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
 
+        db_config = hs.config.database.get_single_database()
         self.master_store = self.hs.get_datastore()
         self.storage = hs.get_storage()
+        database = hs.get_datastores().databases[0]
         self.slaved_store = self.STORE_TYPE(
-            Database(hs), self.hs.get_db_conn(), self.hs
+            database, make_conn(db_config, database.engine), self.hs
         )
         self.event_id = 0
 
diff --git a/tests/server.py b/tests/server.py
index 2b7cf4242e..a554dfdd57 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -302,41 +302,42 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
     Set up a synchronous test server, driven by the reactor used by
     the homeserver.
     """
-    d = _sth(cleanup_func, *args, **kwargs).result
+    server = _sth(cleanup_func, *args, **kwargs)
 
-    if isinstance(d, Failure):
-        d.raiseException()
+    database = server.config.database.get_single_database()
 
     # Make the thread pool synchronous.
-    clock = d.get_clock()
-    pool = d.get_db_pool()
-
-    def runWithConnection(func, *args, **kwargs):
-        return threads.deferToThreadPool(
-            pool._reactor,
-            pool.threadpool,
-            pool._runWithConnection,
-            func,
-            *args,
-            **kwargs
-        )
-
-    def runInteraction(interaction, *args, **kwargs):
-        return threads.deferToThreadPool(
-            pool._reactor,
-            pool.threadpool,
-            pool._runInteraction,
-            interaction,
-            *args,
-            **kwargs
-        )
+    clock = server.get_clock()
+
+    for database in server.get_datastores().databases:
+        pool = database._db_pool
+
+        def runWithConnection(func, *args, **kwargs):
+            return threads.deferToThreadPool(
+                pool._reactor,
+                pool.threadpool,
+                pool._runWithConnection,
+                func,
+                *args,
+                **kwargs
+            )
+
+        def runInteraction(interaction, *args, **kwargs):
+            return threads.deferToThreadPool(
+                pool._reactor,
+                pool.threadpool,
+                pool._runInteraction,
+                interaction,
+                *args,
+                **kwargs
+            )
 
-    if pool:
         pool.runWithConnection = runWithConnection
         pool.runInteraction = runInteraction
         pool.threadpool = ThreadPool(clock._reactor)
         pool.running = True
-    return d
+
+    return server
 
 
 def get_clock():
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 2e521e9ab7..fd52512696 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -28,7 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_conn
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
@@ -55,8 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        database = Database(hs)
-        self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
+        database = hs.get_datastores().databases[0]
+        self.store = ApplicationServiceStore(
+            database, make_conn(database._database_config, database.engine), hs
+        )
 
     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
@@ -111,9 +113,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         hs.config.event_cache_size = 1
         hs.config.password_providers = []
 
-        self.db_pool = hs.get_db_pool()
-        self.engine = hs.database_engine
-
         self.as_list = [
             {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
             {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},
@@ -125,8 +124,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         self.as_yaml_files = []
 
-        database = Database(hs)
-        self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
+        # We assume there is only one database in these tests
+        database = hs.get_datastores().databases[0]
+        self.db_pool = database._db_pool
+        self.engine = database.engine
+
+        db_config = hs.config.get_single_database()
+        self.store = TestTransactionStore(
+            database, make_conn(db_config, self.engine), hs
+        )
 
     def _add_service(self, url, as_token, id):
         as_yaml = dict(
@@ -419,7 +425,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.event_cache_size = 1
         hs.config.password_providers = []
 
-        ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+        database = hs.get_datastores().databases[0]
+        ApplicationServiceStore(
+            database, make_conn(database._database_config, database.engine), hs
+        )
 
     @defer.inlineCallbacks
     def test_duplicate_ids(self):
@@ -435,7 +444,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+            database = hs.get_datastores().databases[0]
+            ApplicationServiceStore(
+                database, make_conn(database._database_config, database.engine), hs
+            )
 
         e = cm.exception
         self.assertIn(f1, str(e))
@@ -456,7 +468,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+            database = hs.get_datastores().databases[0]
+            ApplicationServiceStore(
+                database, make_conn(database._database_config, database.engine), hs
+            )
 
         e = cm.exception
         self.assertIn(f1, str(e))
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 537cfe9f64..cdee0a9e60 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -52,15 +52,17 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         config = Mock()
         config._disable_native_upserts = True
         config.event_cache_size = 1
-        config.database_config = {"name": "sqlite3"}
-        engine = create_engine(config.database_config)
+        hs = TestHomeServer("test", config=config)
+
+        sqlite_config = {"name": "sqlite3"}
+        engine = create_engine(sqlite_config)
         fake_engine = Mock(wraps=engine)
         fake_engine.can_native_upsert = False
-        hs = TestHomeServer(
-            "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
-        )
 
-        self.datastore = SQLBaseStore(Database(hs), None, hs)
+        db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+        db._db_pool = self.db_pool
+
+        self.datastore = SQLBaseStore(db, None, hs)
 
     @defer.inlineCallbacks
     def test_insert_1col(self):
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4578cc3b60..ed5786865a 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         hs = yield setup_test_homeserver(self.addCleanup)
-        self.db_pool = hs.get_db_pool()
 
         self.store = hs.get_datastore()
 
diff --git a/tests/utils.py b/tests/utils.py
index 585f305b9a..9f5bf40b4b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -30,6 +30,7 @@ from twisted.internet import defer, reactor
 from synapse.api.constants import EventTypes
 from synapse.api.errors import CodeMessageException, cs_error
 from synapse.api.room_versions import RoomVersions
+from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
 from synapse.federation.transport import server as federation_server
@@ -177,7 +178,6 @@ class TestHomeServer(HomeServer):
     DATASTORE_CLASS = DataStore
 
 
-@defer.inlineCallbacks
 def setup_test_homeserver(
     cleanup_func,
     name="test",
@@ -214,7 +214,7 @@ def setup_test_homeserver(
     if USE_POSTGRES_FOR_TESTS:
         test_db = "synapse_test_%s" % uuid.uuid4().hex
 
-        config.database_config = {
+        database_config = {
             "name": "psycopg2",
             "args": {
                 "database": test_db,
@@ -226,12 +226,15 @@ def setup_test_homeserver(
             },
         }
     else:
-        config.database_config = {
+        database_config = {
             "name": "sqlite3",
             "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
         }
 
-    db_engine = create_engine(config.database_config)
+    database = DatabaseConnectionConfig("master", database_config, ["main"])
+    config.database.databases = [database]
+
+    db_engine = create_engine(database.config)
 
     # Create the database before we actually try and connect to it, based off
     # the template database we generate in setupdb()
@@ -251,11 +254,6 @@ def setup_test_homeserver(
         cur.close()
         db_conn.close()
 
-    # we need to configure the connection pool to run the on_new_connection
-    # function, so that we can test code that uses custom sqlite functions
-    # (like rank).
-    config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
-
     if datastore is None:
         hs = homeserverToUse(
             name,
@@ -267,21 +265,19 @@ def setup_test_homeserver(
             **kargs
         )
 
-        # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
-        # date db
-        if not isinstance(db_engine, PostgresEngine):
-            db_conn = hs.get_db_conn()
-            yield prepare_database(db_conn, db_engine, config)
-            db_conn.commit()
-            db_conn.close()
+        hs.setup()
+        if homeserverToUse.__name__ == "TestHomeServer":
+            hs.setup_master()
+
+        if isinstance(db_engine, PostgresEngine):
+            database = hs.get_datastores().databases[0]
 
-        else:
             # We need to do cleanup on PostgreSQL
             def cleanup():
                 import psycopg2
 
                 # Close all the db pools
-                hs.get_db_pool().close()
+                database._db_pool.close()
 
                 dropped = False
 
@@ -320,23 +316,12 @@ def setup_test_homeserver(
                 # Register the cleanup hook
                 cleanup_func(cleanup)
 
-        hs.setup()
-        if homeserverToUse.__name__ == "TestHomeServer":
-            hs.setup_master()
     else:
-        # If we have been given an explicit datastore we probably want to mock
-        # out the DataStores somehow too. This all feels a bit wrong, but then
-        # mocking the stores feels wrong too.
-        datastores = Mock(datastore=datastore)
-
         hs = homeserverToUse(
             name,
-            db_pool=None,
             datastore=datastore,
-            datastores=datastores,
             config=config,
             version_string="Synapse/tests",
-            database_engine=db_engine,
             tls_server_context_factory=Mock(),
             tls_client_options_factory=Mock(),
             reactor=reactor,