diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 675959c56c..77376b348e 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -15,6 +15,7 @@
from tests import unittest
from twisted.internet import defer
+from tests.utils import setup_test_homeserver
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.server import HomeServer
from synapse.storage.appservice import (
@@ -33,14 +34,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.as_yaml_files = []
- db_pool = SQLiteMemoryDbPool()
- yield db_pool.prepare()
- hs = HomeServer(
- "test", db_pool=db_pool, clock=MockClock(),
- config=Mock(
- app_service_config_files=self.as_yaml_files
- )
+ config = Mock(
+ app_service_config_files=self.as_yaml_files
)
+ hs = yield setup_test_homeserver(config=config)
self.as_token = "token1"
self.as_url = "some_url"
@@ -102,8 +99,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.as_yaml_files = []
- self.db_pool = SQLiteMemoryDbPool()
- yield self.db_pool.prepare()
+
+ config = Mock(
+ app_service_config_files=self.as_yaml_files
+ )
+ hs = yield setup_test_homeserver(config=config)
+ self.db_pool = hs.get_db_pool()
+
self.as_list = [
{
"token": "token1",
@@ -129,11 +131,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
for s in self.as_list:
yield self._add_service(s["url"], s["token"])
- hs = HomeServer(
- "test", db_pool=self.db_pool, clock=MockClock(), config=Mock(
- app_service_config_files=self.as_yaml_files
- )
- )
+ self.as_yaml_files = []
+
self.store = TestTransactionStore(hs)
def _add_service(self, url, as_token):
@@ -302,7 +301,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
(service.id,)
)
self.assertEquals(1, len(res))
- self.assertEquals(str(txn_id), res[0][0])
+ self.assertEquals(txn_id, res[0][0])
res = yield self.db_pool.runQuery(
"SELECT * FROM application_services_txns WHERE txn_id=?",
@@ -325,7 +324,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
(service.id,)
)
self.assertEquals(1, len(res))
- self.assertEquals(str(txn_id), res[0][0])
+ self.assertEquals(txn_id, res[0][0])
self.assertEquals(ApplicationServiceState.UP, res[0][1])
res = yield self.db_pool.runQuery(
|