diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index e7472e3a93..3dae83c543 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -20,6 +20,7 @@ from synapse.replication.tcp.client import (
ReplicationClientHandler,
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.storage.database import Database
from tests import unittest
from tests.server import FakeTransport
@@ -42,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
self.master_store = self.hs.get_datastore()
self.storage = hs.get_storage()
- self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
+ self.slaved_store = self.STORE_TYPE(
+ Database(hs), self.hs.get_db_conn(), self.hs
+ )
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index dfeea24599..2e521e9ab7 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -28,6 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
+from synapse.storage.database import Database
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -54,7 +55,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
- self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
+ database = Database(hs)
+ self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
def tearDown(self):
# TODO: suboptimal that we need to create files for tests!
@@ -123,7 +125,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files = []
- self.store = TestTransactionStore(hs.get_db_conn(), hs)
+ database = Database(hs)
+ self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
def _add_service(self, url, as_token, id):
as_yaml = dict(
@@ -382,8 +385,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
- def __init__(self, db_conn, hs):
- super(TestTransactionStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(TestTransactionStore, self).__init__(database, db_conn, hs)
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
@@ -416,7 +419,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs.config.event_cache_size = 1
hs.config.password_providers = []
- ApplicationServiceStore(hs.get_db_conn(), hs)
+ ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
@defer.inlineCallbacks
def test_duplicate_ids(self):
@@ -432,7 +435,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(hs.get_db_conn(), hs)
+ ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
e = cm.exception
self.assertIn(f1, str(e))
@@ -453,7 +456,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(hs.get_db_conn(), hs)
+ ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
e = cm.exception
self.assertIn(f1, str(e))
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 7915d48a9e..537cfe9f64 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -21,6 +21,7 @@ from mock import Mock
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.storage.engines import create_engine
from tests import unittest
@@ -59,7 +60,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
"test", db_pool=self.db_pool, config=config, database_engine=fake_engine
)
- self.datastore = SQLBaseStore(None, hs)
+ self.datastore = SQLBaseStore(Database(hs), None, hs)
@defer.inlineCallbacks
def test_insert_1col(self):
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 24c7fe16c3..9b6f7211ae 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -16,7 +16,6 @@
from twisted.internet import defer
-from synapse.storage.data_stores.main.profile import ProfileStore
from synapse.types import UserID
from tests import unittest
@@ -28,7 +27,7 @@ class ProfileStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
- self.store = ProfileStore(hs.get_db_conn(), hs)
+ self.store = hs.get_datastore()
self.u_frank = UserID.from_string("@frank:test")
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 7eea57c0e2..6a545d2eb0 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -15,8 +15,6 @@
from twisted.internet import defer
-from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
-
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(self.addCleanup)
- self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs)
+ self.store = self.hs.get_datastore()
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 7d82b58466..ad165d7295 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -33,6 +33,8 @@ class MessageAcceptTests(unittest.TestCase):
self.reactor.advance(0.1)
self.room_id = self.successResultOf(room)["room_id"]
+ self.store = self.homeserver.get_datastore()
+
# Figure out what the most recent event is
most_recent = self.successResultOf(
maybeDeferred(
@@ -77,10 +79,7 @@ class MessageAcceptTests(unittest.TestCase):
# Make sure we actually joined the room
self.assertEqual(
self.successResultOf(
- maybeDeferred(
- self.homeserver.get_datastore().get_latest_event_ids_in_room,
- self.room_id,
- )
+ maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
)[0],
"$join:test.serv",
)
@@ -100,10 +99,7 @@ class MessageAcceptTests(unittest.TestCase):
# Figure out what the most recent event is
most_recent = self.successResultOf(
- maybeDeferred(
- self.homeserver.get_datastore().get_latest_event_ids_in_room,
- self.room_id,
- )
+ maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
)[0]
# Now lie about an event
@@ -141,7 +137,5 @@ class MessageAcceptTests(unittest.TestCase):
)
# Make sure the invalid event isn't there
- extrem = maybeDeferred(
- self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id
- )
+ extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
|