diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ef296e7dab..17fbde284a 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -24,13 +24,14 @@ from twisted.internet import defer
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.config._base import ConfigError
-from synapse.storage.data_stores.main.appservice import (
+from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.databases.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
-from synapse.storage.database import Database, make_conn
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
@@ -178,14 +179,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_appservice_state_none(self):
service = Mock(id="999")
- state = yield self.store.get_appservice_state(service)
+ state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(None, state)
@defer.inlineCallbacks
def test_get_appservice_state_up(self):
yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
service = Mock(id=self.as_list[0]["id"])
- state = yield self.store.get_appservice_state(service)
+ state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(ApplicationServiceState.UP, state)
@defer.inlineCallbacks
@@ -194,20 +195,22 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
service = Mock(id=self.as_list[1]["id"])
- state = yield self.store.get_appservice_state(service)
+ state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(ApplicationServiceState.DOWN, state)
@defer.inlineCallbacks
def test_get_appservices_by_state_none(self):
- services = yield self.store.get_appservices_by_state(
- ApplicationServiceState.DOWN
+ services = yield defer.ensureDeferred(
+ self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
)
self.assertEquals(0, len(services))
@defer.inlineCallbacks
def test_set_appservices_state_down(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ )
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@@ -219,9 +222,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_multiple_up(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
- yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ )
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ )
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ )
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@@ -339,7 +348,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
def test_get_oldest_unsent_txn_none(self):
service = Mock(id=self.as_list[0]["id"])
- txn = yield self.store.get_oldest_unsent_txn(service)
+ txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
self.assertEquals(None, txn)
@defer.inlineCallbacks
@@ -349,14 +358,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out
- self.store.get_events_as_list = Mock(return_value=events)
+ self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events)
yield self._insert_txn(service.id, 11, other_events)
yield self._insert_txn(service.id, 12, other_events)
- txn = yield self.store.get_oldest_unsent_txn(service)
+ txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
self.assertEquals(service, txn.service)
self.assertEquals(10, txn.id)
self.assertEquals(events, txn.events)
@@ -366,8 +375,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
- services = yield self.store.get_appservices_by_state(
- ApplicationServiceState.DOWN
+ services = yield defer.ensureDeferred(
+ self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
)
self.assertEquals(1, len(services))
self.assertEquals(self.as_list[0]["id"], services[0].id)
@@ -379,8 +388,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
- services = yield self.store.get_appservices_by_state(
- ApplicationServiceState.DOWN
+ services = yield defer.ensureDeferred(
+ self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
)
self.assertEquals(2, len(services))
self.assertEquals(
@@ -391,7 +400,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(TestTransactionStore, self).__init__(database, db_conn, hs)
|