summary refs log tree commit diff
path: root/tests/storage/test_appservice.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_appservice.py')
-rw-r--r--tests/storage/test_appservice.py47
1 files changed, 28 insertions, 19 deletions
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ef296e7dab..17fbde284a 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -24,13 +24,14 @@ from twisted.internet import defer
 
 from synapse.appservice import ApplicationService, ApplicationServiceState
 from synapse.config._base import ConfigError
-from synapse.storage.data_stores.main.appservice import (
+from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.databases.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
-from synapse.storage.database import Database, make_conn
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.utils import setup_test_homeserver
 
 
@@ -178,14 +179,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_get_appservice_state_none(self):
         service = Mock(id="999")
-        state = yield self.store.get_appservice_state(service)
+        state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
         self.assertEquals(None, state)
 
     @defer.inlineCallbacks
     def test_get_appservice_state_up(self):
         yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
         service = Mock(id=self.as_list[0]["id"])
-        state = yield self.store.get_appservice_state(service)
+        state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
         self.assertEquals(ApplicationServiceState.UP, state)
 
     @defer.inlineCallbacks
@@ -194,20 +195,22 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
         yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
         service = Mock(id=self.as_list[1]["id"])
-        state = yield self.store.get_appservice_state(service)
+        state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
         self.assertEquals(ApplicationServiceState.DOWN, state)
 
     @defer.inlineCallbacks
     def test_get_appservices_by_state_none(self):
-        services = yield self.store.get_appservices_by_state(
-            ApplicationServiceState.DOWN
+        services = yield defer.ensureDeferred(
+            self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
         )
         self.assertEquals(0, len(services))
 
     @defer.inlineCallbacks
     def test_set_appservices_state_down(self):
         service = Mock(id=self.as_list[1]["id"])
-        yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        )
         rows = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
                 "SELECT as_id FROM application_services_state WHERE state=?"
@@ -219,9 +222,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_set_appservices_state_multiple_up(self):
         service = Mock(id=self.as_list[1]["id"])
-        yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
-        yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
-        yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        )
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        )
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        )
         rows = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
                 "SELECT as_id FROM application_services_state WHERE state=?"
@@ -339,7 +348,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     def test_get_oldest_unsent_txn_none(self):
         service = Mock(id=self.as_list[0]["id"])
 
-        txn = yield self.store.get_oldest_unsent_txn(service)
+        txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
         self.assertEquals(None, txn)
 
     @defer.inlineCallbacks
@@ -349,14 +358,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
 
         # we aren't testing store._base stuff here, so mock this out
-        self.store.get_events_as_list = Mock(return_value=events)
+        self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
 
         yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
         yield self._insert_txn(service.id, 10, events)
         yield self._insert_txn(service.id, 11, other_events)
         yield self._insert_txn(service.id, 12, other_events)
 
-        txn = yield self.store.get_oldest_unsent_txn(service)
+        txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
         self.assertEquals(service, txn.service)
         self.assertEquals(10, txn.id)
         self.assertEquals(events, txn.events)
@@ -366,8 +375,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
         yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
 
-        services = yield self.store.get_appservices_by_state(
-            ApplicationServiceState.DOWN
+        services = yield defer.ensureDeferred(
+            self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
         )
         self.assertEquals(1, len(services))
         self.assertEquals(self.as_list[0]["id"], services[0].id)
@@ -379,8 +388,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
         yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
 
-        services = yield self.store.get_appservices_by_state(
-            ApplicationServiceState.DOWN
+        services = yield defer.ensureDeferred(
+            self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
         )
         self.assertEquals(2, len(services))
         self.assertEquals(
@@ -391,7 +400,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
 # required for ApplicationServiceTransactionStoreTestCase tests
 class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(TestTransactionStore, self).__init__(database, db_conn, hs)