summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-12-18 10:45:12 +0000
committerGitHub <noreply@github.com>2019-12-18 10:45:12 +0000
commit2284eb3a533a2df04784df08da28e67d6588a5ea (patch)
treedf75df4b3eba90e8299c8bae61157d075a0d423b /tests/storage
parentMerge branch 'master' into develop (diff)
downloadsynapse-2284eb3a533a2df04784df08da28e67d6588a5ea.tar.xz
Add database config class (#6513)
This encapsulates config for a given database and is the way to get new
connections.
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_appservice.py37
-rw-r--r--tests/storage/test_base.py14
-rw-r--r--tests/storage/test_registration.py1
3 files changed, 34 insertions, 18 deletions
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 2e521e9ab7..fd52512696 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -28,7 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_conn
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
@@ -55,8 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        database = Database(hs)
-        self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
+        database = hs.get_datastores().databases[0]
+        self.store = ApplicationServiceStore(
+            database, make_conn(database._database_config, database.engine), hs
+        )
 
     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
@@ -111,9 +113,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         hs.config.event_cache_size = 1
         hs.config.password_providers = []
 
-        self.db_pool = hs.get_db_pool()
-        self.engine = hs.database_engine
-
         self.as_list = [
             {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
             {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},
@@ -125,8 +124,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         self.as_yaml_files = []
 
-        database = Database(hs)
-        self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
+        # We assume there is only one database in these tests
+        database = hs.get_datastores().databases[0]
+        self.db_pool = database._db_pool
+        self.engine = database.engine
+
+        db_config = hs.config.get_single_database()
+        self.store = TestTransactionStore(
+            database, make_conn(db_config, self.engine), hs
+        )
 
     def _add_service(self, url, as_token, id):
         as_yaml = dict(
@@ -419,7 +425,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.event_cache_size = 1
         hs.config.password_providers = []
 
-        ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+        database = hs.get_datastores().databases[0]
+        ApplicationServiceStore(
+            database, make_conn(database._database_config, database.engine), hs
+        )
 
     @defer.inlineCallbacks
     def test_duplicate_ids(self):
@@ -435,7 +444,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+            database = hs.get_datastores().databases[0]
+            ApplicationServiceStore(
+                database, make_conn(database._database_config, database.engine), hs
+            )
 
         e = cm.exception
         self.assertIn(f1, str(e))
@@ -456,7 +468,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+            database = hs.get_datastores().databases[0]
+            ApplicationServiceStore(
+                database, make_conn(database._database_config, database.engine), hs
+            )
 
         e = cm.exception
         self.assertIn(f1, str(e))
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 537cfe9f64..cdee0a9e60 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -52,15 +52,17 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         config = Mock()
         config._disable_native_upserts = True
         config.event_cache_size = 1
-        config.database_config = {"name": "sqlite3"}
-        engine = create_engine(config.database_config)
+        hs = TestHomeServer("test", config=config)
+
+        sqlite_config = {"name": "sqlite3"}
+        engine = create_engine(sqlite_config)
         fake_engine = Mock(wraps=engine)
         fake_engine.can_native_upsert = False
-        hs = TestHomeServer(
-            "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
-        )
 
-        self.datastore = SQLBaseStore(Database(hs), None, hs)
+        db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+        db._db_pool = self.db_pool
+
+        self.datastore = SQLBaseStore(db, None, hs)
 
     @defer.inlineCallbacks
     def test_insert_1col(self):
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4578cc3b60..ed5786865a 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         hs = yield setup_test_homeserver(self.addCleanup)
-        self.db_pool = hs.get_db_pool()
 
         self.store = hs.get_datastore()