summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-12-09 13:53:21 +0000
committerGitHub <noreply@github.com>2019-12-09 13:53:21 +0000
commit30e9adf32f7db9aa44693876d33692bf4590d363 (patch)
tree9b83e1dc2eeefaafe8eca4e5b95cffdf8572b346 /tests
parentMerge pull request #6493 from matrix-org/erikj/invite_state_config (diff)
parentFix comment (diff)
downloadsynapse-30e9adf32f7db9aa44693876d33692bf4590d363.tar.xz
Merge pull request #6487 from matrix-org/erikj/pass_in_db
Pass in Database object to data stores.
Diffstat (limited to 'tests')
-rw-r--r--tests/replication/slave/storage/_base.py5
-rw-r--r--tests/storage/test_appservice.py17
-rw-r--r--tests/storage/test_base.py3
-rw-r--r--tests/storage/test_profile.py3
-rw-r--r--tests/storage/test_user_directory.py4
-rw-r--r--tests/test_federation.py16
6 files changed, 23 insertions, 25 deletions
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index e7472e3a93..3dae83c543 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -20,6 +20,7 @@ from synapse.replication.tcp.client import (
     ReplicationClientHandler,
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.storage.database import Database
 
 from tests import unittest
 from tests.server import FakeTransport
@@ -42,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
 
         self.master_store = self.hs.get_datastore()
         self.storage = hs.get_storage()
-        self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
+        self.slaved_store = self.STORE_TYPE(
+            Database(hs), self.hs.get_db_conn(), self.hs
+        )
         self.event_id = 0
 
         server_factory = ReplicationStreamProtocolFactory(self.hs)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index dfeea24599..2e521e9ab7 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -28,6 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
+from synapse.storage.database import Database
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
@@ -54,7 +55,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
+        database = Database(hs)
+        self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
 
     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
@@ -123,7 +125,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         self.as_yaml_files = []
 
-        self.store = TestTransactionStore(hs.get_db_conn(), hs)
+        database = Database(hs)
+        self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
 
     def _add_service(self, url, as_token, id):
         as_yaml = dict(
@@ -382,8 +385,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
 # required for ApplicationServiceTransactionStoreTestCase tests
 class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
-    def __init__(self, db_conn, hs):
-        super(TestTransactionStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(TestTransactionStore, self).__init__(database, db_conn, hs)
 
 
 class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
@@ -416,7 +419,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.event_cache_size = 1
         hs.config.password_providers = []
 
-        ApplicationServiceStore(hs.get_db_conn(), hs)
+        ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
 
     @defer.inlineCallbacks
     def test_duplicate_ids(self):
@@ -432,7 +435,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs.get_db_conn(), hs)
+            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
 
         e = cm.exception
         self.assertIn(f1, str(e))
@@ -453,7 +456,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs.get_db_conn(), hs)
+            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
 
         e = cm.exception
         self.assertIn(f1, str(e))
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 7915d48a9e..537cfe9f64 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -21,6 +21,7 @@ from mock import Mock
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.storage.engines import create_engine
 
 from tests import unittest
@@ -59,7 +60,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
         )
 
-        self.datastore = SQLBaseStore(None, hs)
+        self.datastore = SQLBaseStore(Database(hs), None, hs)
 
     @defer.inlineCallbacks
     def test_insert_1col(self):
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 24c7fe16c3..9b6f7211ae 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -16,7 +16,6 @@
 
 from twisted.internet import defer
 
-from synapse.storage.data_stores.main.profile import ProfileStore
 from synapse.types import UserID
 
 from tests import unittest
@@ -28,7 +27,7 @@ class ProfileStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver(self.addCleanup)
 
-        self.store = ProfileStore(hs.get_db_conn(), hs)
+        self.store = hs.get_datastore()
 
         self.u_frank = UserID.from_string("@frank:test")
 
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 7eea57c0e2..6a545d2eb0 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -15,8 +15,6 @@
 
 from twisted.internet import defer
 
-from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
-
 from tests import unittest
 from tests.utils import setup_test_homeserver
 
@@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs)
+        self.store = self.hs.get_datastore()
 
         # alice and bob are both in !room_id. bobby is not but shares
         # a homeserver with alice.
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 7d82b58466..ad165d7295 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -33,6 +33,8 @@ class MessageAcceptTests(unittest.TestCase):
         self.reactor.advance(0.1)
         self.room_id = self.successResultOf(room)["room_id"]
 
+        self.store = self.homeserver.get_datastore()
+
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
             maybeDeferred(
@@ -77,10 +79,7 @@ class MessageAcceptTests(unittest.TestCase):
         # Make sure we actually joined the room
         self.assertEqual(
             self.successResultOf(
-                maybeDeferred(
-                    self.homeserver.get_datastore().get_latest_event_ids_in_room,
-                    self.room_id,
-                )
+                maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
             )[0],
             "$join:test.serv",
         )
@@ -100,10 +99,7 @@ class MessageAcceptTests(unittest.TestCase):
 
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
-            maybeDeferred(
-                self.homeserver.get_datastore().get_latest_event_ids_in_room,
-                self.room_id,
-            )
+            maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
         )[0]
 
         # Now lie about an event
@@ -141,7 +137,5 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         # Make sure the invalid event isn't there
-        extrem = maybeDeferred(
-            self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id
-        )
+        extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
         self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")