diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 5a50e4fdd4..319e2c2325 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
self.table_name = "table_" + hs.get_secrets().token_hex(6)
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"create",
lambda x, *a: x.execute(*a),
"CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
@@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"index",
lambda x, *a: x.execute(*a),
"CREATE UNIQUE INDEX %sindex ON %s(id, username)"
@@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
value_values = [["hello"], ["there"]]
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"test",
- self.storage.db.simple_upsert_many_txn,
+ self.storage.db_pool.simple_upsert_many_txn,
self.table_name,
key_names,
key_values,
@@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
# Check results are what we expect
res = self.get_success(
- self.storage.db.simple_select_list(
+ self.storage.db_pool.simple_select_list(
self.table_name, None, ["id, username, value"]
)
)
@@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
value_values = [["bleb"]]
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"test",
- self.storage.db.simple_upsert_many_txn,
+ self.storage.db_pool.simple_upsert_many_txn,
self.table_name,
key_names,
key_values,
@@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
# Check results are what we expect
res = self.get_success(
- self.storage.db.simple_select_list(
+ self.storage.db_pool.simple_select_list(
self.table_name, None, ["id, username, value"]
)
)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ef296e7dab..1b516b7976 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -24,11 +24,11 @@ from twisted.internet import defer
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.config._base import ConfigError
-from synapse.storage.data_stores.main.appservice import (
+from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.databases.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
-from synapse.storage.database import Database, make_conn
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -391,7 +391,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(TestTransactionStore, self).__init__(database, db_conn, hs)
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 940b166129..2efbc97c2e 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -9,7 +9,9 @@ from tests import unittest
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
- self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater
+ self.updates = (
+ self.hs.get_datastore().db_pool.updates
+ ) # type: BackgroundUpdater
# the base test class should have run the real bg updates for us
self.assertTrue(
self.get_success(self.updates.has_completed_background_updates())
@@ -29,7 +31,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
store = self.hs.get_datastore()
self.get_success(
- store.db.simple_insert(
+ store.db_pool.simple_insert(
"background_updates",
values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
)
@@ -40,7 +42,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def update(progress, count):
yield self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1}
- yield store.db.runInteraction(
+ yield store.db_pool.runInteraction(
"update_progress",
self.updates._background_update_progress_txn,
"test_update",
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index b589506c60..efcaeef1e7 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -21,7 +21,7 @@ from mock import Mock
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.engines import create_engine
from tests import unittest
@@ -57,7 +57,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
fake_engine = Mock(wraps=engine)
fake_engine.can_native_upsert = False
- db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+ db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
db._db_pool = self.db_pool
self.datastore = SQLBaseStore(db, None, hs)
@@ -66,7 +66,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_insert(
+ yield self.datastore.db_pool.simple_insert(
table="tablename", values={"columname": "Value"}
)
@@ -78,7 +78,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_3cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_insert(
+ yield self.datastore.db_pool.simple_insert(
table="tablename",
# Use OrderedDict() so we can assert on the SQL generated
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
@@ -93,7 +93,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
- value = yield self.datastore.db.simple_select_one_onecol(
+ value = yield self.datastore.db_pool.simple_select_one_onecol(
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
)
@@ -107,7 +107,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
- ret = yield self.datastore.db.simple_select_one(
+ ret = yield self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"],
@@ -123,7 +123,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
- ret = yield self.datastore.db.simple_select_one(
+ ret = yield self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "Not here"},
retcols=["colA"],
@@ -138,7 +138,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
- ret = yield self.datastore.db.simple_select_list(
+ ret = yield self.datastore.db_pool.simple_select_list(
table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
)
@@ -151,7 +151,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_update_one(
+ yield self.datastore.db_pool.simple_update_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
updatevalues={"columnname": "New Value"},
@@ -166,7 +166,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_4cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_update_one(
+ yield self.datastore.db_pool.simple_update_one(
table="tablename",
keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
@@ -181,7 +181,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_delete_one(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_delete_one(
+ yield self.datastore.db_pool.simple_delete_one(
table="tablename", keyvalues={"keycol": "Go away"}
)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 43425c969a..3fab5a5248 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -47,12 +47,12 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
"""
# Make sure we don't clash with in progress updates.
self.assertTrue(
- self.store.db.updates._all_done, "Background updates are still ongoing"
+ self.store.db_pool.updates._all_done, "Background updates are still ongoing"
)
schema_path = os.path.join(
prepare_database.dir_path,
- "data_stores",
+ "databases",
"main",
"schema",
"delta",
@@ -64,19 +64,19 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
prepare_database.executescript(txn, schema_path)
self.get_success(
- self.store.db.runInteraction(
+ self.store.db_pool.runInteraction(
"test_delete_forward_extremities", run_delta_file
)
)
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
def test_soft_failed_extremities_handled_correctly(self):
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 3b483bc7f0..224ea6fd79 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -86,7 +86,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.pump(0)
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -117,7 +117,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.pump(0)
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -204,10 +204,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_devices_last_seen_bg_update(self):
# First make sure we have completed all updates.
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
user_id = "@user:id"
@@ -225,7 +225,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# But clear the associated entry in devices table
self.get_success(
- self.store.db.simple_update(
+ self.store.db_pool.simple_update(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues={"last_seen": None, "ip": None, "user_agent": None},
@@ -252,7 +252,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# Register the background update to run again.
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
table="background_updates",
values={
"update_name": "devices_last_seen",
@@ -263,14 +263,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
# ... and tell the DataStore that it hasn't finished all updates yet
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# We should now get the correct result again
@@ -293,10 +293,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_old_user_ips_pruned(self):
# First make sure we have completed all updates.
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
user_id = "@user:id"
@@ -315,7 +315,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should see that in the DB
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -341,7 +341,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should get no results.
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 4e128e1047..daac947cb2 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -34,8 +34,10 @@ class DirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_room_to_alias(self):
- yield self.store.create_room_alias_association(
- room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ yield defer.ensureDeferred(
+ self.store.create_room_alias_association(
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ )
)
self.assertEquals(
@@ -45,24 +47,36 @@ class DirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_alias_to_room(self):
- yield self.store.create_room_alias_association(
- room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ yield defer.ensureDeferred(
+ self.store.create_room_alias_association(
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ )
)
self.assertObjectHasAttributes(
{"room_id": self.room.to_string(), "servers": ["test"]},
- (yield self.store.get_association_from_room_alias(self.alias)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_association_from_room_alias(self.alias)
+ )
+ ),
)
@defer.inlineCallbacks
def test_delete_alias(self):
- yield self.store.create_room_alias_association(
- room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ yield defer.ensureDeferred(
+ self.store.create_room_alias_association(
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ )
)
- room_id = yield self.store.delete_room_alias(self.alias)
+ room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
self.assertEqual(self.room.to_string(), room_id)
self.assertIsNone(
- (yield self.store.get_association_from_room_alias(self.alias))
+ (
+ yield defer.ensureDeferred(
+ self.store.get_association_from_room_alias(self.alias)
+ )
+ )
)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 398d546280..9f8d30373b 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -34,7 +34,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield self.store.set_e2e_device_keys("user", "device", now, json)
- res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_device_keys((("user", "device"),))
+ )
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
@@ -63,7 +65,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield self.store.set_e2e_device_keys("user", "device", now, json)
yield self.store.store_device("user", "device", "display_name")
- res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_device_keys((("user", "device"),))
+ )
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
@@ -85,8 +89,8 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
- res = yield self.store.get_e2e_device_keys(
- (("user1", "device1"), ("user2", "device2"))
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
)
self.assertIn("user1", res)
self.assertIn("device1", res["user1"])
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 3aeec0dc0f..d4c3b867e3 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -56,7 +56,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
for i in range(0, 20):
- self.get_success(self.store.db.runInteraction("insert", insert_event, i))
+ self.get_success(
+ self.store.db_pool.runInteraction("insert", insert_event, i)
+ )
# this should get the last ten
r = self.get_success(self.store.get_prev_events_for_room(room_id))
@@ -81,13 +83,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
for i in range(0, 20):
self.get_success(
- self.store.db.runInteraction("insert", insert_event, i, room1)
+ self.store.db_pool.runInteraction("insert", insert_event, i, room1)
)
self.get_success(
- self.store.db.runInteraction("insert", insert_event, i, room2)
+ self.store.db_pool.runInteraction("insert", insert_event, i, room2)
)
self.get_success(
- self.store.db.runInteraction("insert", insert_event, i, room3)
+ self.store.db_pool.runInteraction("insert", insert_event, i, room3)
)
# Test simple case
@@ -164,7 +166,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
depth = depth_map[event_id]
- self.store.db.simple_insert_txn(
+ self.store.db_pool.simple_insert_txn(
txn,
table="events",
values={
@@ -179,7 +181,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
},
)
- self.store.db.simple_insert_many_txn(
+ self.store.db_pool.simple_insert_many_txn(
txn,
table="event_auth",
values=[
@@ -192,7 +194,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
for event_id in auth_graph:
next_stream_ordering += 1
self.get_success(
- self.store.db.runInteraction(
+ self.store.db_pool.runInteraction(
"insert", insert_event, event_id, next_stream_ordering
)
)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index b45bc9c115..857db071d4 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -39,14 +39,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self):
- yield self.store.get_unread_push_actions_for_user_in_range_for_http(
- USER_ID, 0, 1000, 20
+ yield defer.ensureDeferred(
+ self.store.get_unread_push_actions_for_user_in_range_for_http(
+ USER_ID, 0, 1000, 20
+ )
)
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_email(self):
- yield self.store.get_unread_push_actions_for_user_in_range_for_email(
- USER_ID, 0, 1000, 20
+ yield defer.ensureDeferred(
+ self.store.get_unread_push_actions_for_user_in_range_for_email(
+ USER_ID, 0, 1000, 20
+ )
)
@defer.inlineCallbacks
@@ -56,7 +60,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
- counts = yield self.store.db.runInteraction(
+ counts = yield self.store.db_pool.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
)
self.assertEquals(
@@ -72,10 +76,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream
event.depth = stream
- yield self.store.add_push_actions_to_staging(
- event.event_id, {user_id: action}
+ yield defer.ensureDeferred(
+ self.store.add_push_actions_to_staging(
+ event.event_id, {user_id: action}
+ )
)
- yield self.store.db.runInteraction(
+ yield self.store.db_pool.runInteraction(
"",
self.persist_events_store._set_push_actions_for_event_and_users_txn,
[(event, None)],
@@ -83,12 +89,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
def _rotate(stream):
- return self.store.db.runInteraction(
+ return self.store.db_pool.runInteraction(
"", self.store._rotate_notifs_before_txn, stream
)
def _mark_read(stream, depth):
- return self.store.db.runInteraction(
+ return self.store.db_pool.runInteraction(
"",
self.store._remove_old_push_actions_before_txn,
room_id,
@@ -117,7 +123,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield _inject_actions(6, PlAIN_NOTIF)
yield _rotate(7)
- yield self.store.db.simple_delete(
+ yield self.store.db_pool.simple_delete(
table="event_push_actions", keyvalues={"1": 1}, desc=""
)
@@ -136,7 +142,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return self.store.db.simple_insert(
+ return self.store.db_pool.simple_insert(
"events",
{
"stream_ordering": so,
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 55e9ecf264..e845410dae 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -14,7 +14,7 @@
# limitations under the License.
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.unittest import HomeserverTestCase
@@ -27,9 +27,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- self.db = self.store.db # type: Database
+ self.db_pool = self.store.db_pool # type: DatabasePool
- self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn):
txn.execute("CREATE SEQUENCE foobar_seq")
@@ -47,7 +47,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
def _create(conn):
return MultiWriterIdGenerator(
conn,
- self.db,
+ self.db_pool,
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
@@ -55,7 +55,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
sequence_name="foobar_seq",
)
- return self.get_success(self.db.runWithConnection(_create))
+ return self.get_success(self.db_pool.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int):
def _insert(txn):
@@ -65,7 +65,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
(instance_name,),
)
- self.get_success(self.db.runInteraction("test_single_instance", _insert))
+ self.get_success(self.db_pool.runInteraction("test_single_instance", _insert))
def test_empty(self):
"""Test an ID generator against an empty database gives sensible
@@ -178,7 +178,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token("master"), 7)
- self.get_success(self.db.runInteraction("test", _get_next_txn))
+ self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token("master"), 8)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index ab0df5ea93..0155ffd04e 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -36,7 +36,9 @@ class DataStoreTestCase(unittest.TestCase):
def test_get_users_paginate(self):
yield self.store.register_user(self.user.to_string(), "pass")
yield self.store.create_profile(self.user.localpart)
- yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
+ yield self.store.set_profile_displayname(
+ self.user.localpart, self.displayname, 1
+ )
users, total = yield self.store.get_users_paginate(
0, 10, name="bc", guests=False
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 9c04e92577..e793781a26 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import default_config, override_config
FORTY_DAYS = 40 * 24 * 60 * 60
@@ -78,7 +79,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# XXX why are we doing this here? this function is only run at startup
# so it is odd to re-run it here.
self.get_success(
- self.store.db.runInteraction(
+ self.store.db_pool.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
)
@@ -204,7 +205,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_add_threepid(user, "email", email, now, now)
)
- d = self.store.db.runInteraction(
+ d = self.store.db_pool.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
self.get_success(d)
@@ -230,7 +231,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
)
self.get_success(d)
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
d = self.store.populate_monthly_active_users(user_id)
self.get_success(d)
@@ -238,7 +241,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self):
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
@@ -251,7 +256,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self):
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
self.store.user_last_seen_monthly_active = Mock(
@@ -280,7 +287,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
]
self.hs.config.mau_limits_reserved_threepids = threepids
- d = self.store.db.runInteraction(
+ d = self.store.db_pool.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
self.get_success(d)
@@ -333,7 +340,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self):
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9b6f7211ae..7458a37e54 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -33,9 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_displayname(self):
- yield self.store.create_profile(self.u_frank.localpart)
-
- yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1)
self.assertEquals(
"Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
@@ -43,10 +41,8 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_avatar_url(self):
- yield self.store.create_profile(self.u_frank.localpart)
-
yield self.store.set_profile_avatar_url(
- self.u_frank.localpart, "http://my.site/here"
+ self.u_frank.localpart, "http://my.site/here", 1
)
self.assertEquals(
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index b9fafaa1a6..a6012c973d 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
event = self.successResultOf(event)
# Purge everything before this topological token
- purge = storage.purge_events.purge_history(self.room_id, event, True)
+ purge = defer.ensureDeferred(
+ storage.purge_events.purge_history(self.room_id, event, True)
+ )
self.pump()
self.assertEqual(self.successResultOf(purge), None)
@@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase):
)
# Purge everything before this topological token
- purge = storage.purge_history(self.room_id, event, True)
+ purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
self.pump()
f = self.failureResultOf(purge)
self.assertIn("greater than forward", f.value.args[0])
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index db3667dc43..41511d479f 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def build(self, prev_event_ids):
- built_event = yield self._base_builder.build(prev_event_ids)
+ built_event = yield defer.ensureDeferred(
+ self._base_builder.build(prev_event_ids)
+ )
built_event._event_id = self._event_id
built_event._dict["event_id"] = self._event_id
@@ -341,7 +343,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
event_json = self.get_success(
- self.store.db.simple_select_one_onecol(
+ self.store.db_pool.simple_select_one_onecol(
table="event_json",
keyvalues={"event_id": msg_event.event_id},
retcol="json",
@@ -359,7 +361,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.reactor.advance(60 * 60 * 2)
event_json = self.get_success(
- self.store.db.simple_select_one_onecol(
+ self.store.db_pool.simple_select_one_onecol(
table="event_json",
keyvalues={"event_id": msg_event.event_id},
retcol="json",
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 1d77b4a2d6..d07b985a8e 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -37,11 +37,13 @@ class RoomStoreTestCase(unittest.TestCase):
self.alias = RoomAlias.from_string("#a-room-name:test")
self.u_creator = UserID.from_string("@creator:test")
- yield self.store.store_room(
- self.room.to_string(),
- room_creator_user_id=self.u_creator.to_string(),
- is_public=True,
- room_version=RoomVersions.V1,
+ yield defer.ensureDeferred(
+ self.store.store_room(
+ self.room.to_string(),
+ room_creator_user_id=self.u_creator.to_string(),
+ is_public=True,
+ room_version=RoomVersions.V1,
+ )
)
@defer.inlineCallbacks
@@ -88,17 +90,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abcde:test")
- yield self.store.store_room(
- self.room.to_string(),
- room_creator_user_id="@creator:text",
- is_public=True,
- room_version=RoomVersions.V1,
+ yield defer.ensureDeferred(
+ self.store.store_room(
+ self.room.to_string(),
+ room_creator_user_id="@creator:text",
+ is_public=True,
+ room_version=RoomVersions.V1,
+ )
)
@defer.inlineCallbacks
def inject_room_event(self, **kwargs):
- yield self.storage.persistence.persist_event(
- self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(
+ self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+ )
)
@defer.inlineCallbacks
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index f282921538..17c9da4838 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -179,10 +179,10 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def test_can_rerun_update(self):
# First make sure we have completed all updates.
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Now let's create a room, which will insert a membership
@@ -192,7 +192,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
# Register the background update to run again.
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
table="background_updates",
values={
"update_name": "current_state_events_membership",
@@ -203,12 +203,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
)
# ... and tell the DataStore that it hasn't finished all updates yet
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index a0e133cd4a..8bd12fa847 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -44,11 +44,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
- yield self.store.store_room(
- self.room.to_string(),
- room_creator_user_id="@creator:text",
- is_public=True,
- room_version=RoomVersions.V1,
+ yield defer.ensureDeferred(
+ self.store.store_room(
+ self.room.to_string(),
+ room_creator_user_id="@creator:text",
+ is_public=True,
+ room_version=RoomVersions.V1,
+ )
)
@defer.inlineCallbacks
@@ -68,7 +70,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- yield self.storage.persistence.persist_event(event, context)
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
+ )
return event
@@ -87,8 +91,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.storage.state.get_state_groups_ids(
- self.room, [e2.event_id]
+ state_group_map = yield defer.ensureDeferred(
+ self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0]
@@ -106,8 +110,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.storage.state.get_state_groups(
- self.room, [e2.event_id]
+ state_group_map = yield defer.ensureDeferred(
+ self.storage.state.get_state_groups(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
@@ -148,7 +152,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
- state = yield self.storage.state.get_state_for_event(e5.event_id)
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(e5.event_id)
+ )
self.assertIsNotNone(e4)
@@ -164,22 +170,28 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
- state = yield self.storage.state.get_state_for_event(
- e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+ )
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
- state = yield self.storage.state.get_state_for_event(
- e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+ )
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
- state = yield self.storage.state.get_state_for_event(
- e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+ )
)
self.assertStateMapEqual(
@@ -188,12 +200,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
- state = yield self.storage.state.get_state_for_event(
- e5.event_id,
- state_filter=StateFilter(
- types={EventTypes.Member: {self.u_alice.to_string()}},
- include_others=True,
- ),
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {self.u_alice.to_string()}},
+ include_others=True,
+ ),
+ )
)
self.assertStateMapEqual(
@@ -206,11 +220,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
- state = yield self.storage.state.get_state_for_event(
- e5.event_id,
- state_filter=StateFilter(
- types={EventTypes.Member: set()}, include_others=True
- ),
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()}, include_others=True
+ ),
+ )
)
self.assertStateMapEqual(
@@ -222,8 +238,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
- group_ids = yield self.storage.state.get_state_groups_ids(
- room_id, [e5.event_id]
+ group_ids = yield defer.ensureDeferred(
+ self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
)
group = list(group_ids.keys())[0]
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 6a545d2eb0..ecfafe68a9 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -40,7 +40,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
def test_search_user_dir(self):
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
- r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"]))
self.assertDictEqual(
@@ -51,7 +51,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
def test_search_user_dir_all_users(self):
self.hs.config.user_directory_search_all_users = True
try:
- r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(2, len(r["results"]))
self.assertDictEqual(
|