summary refs log tree commit diff
path: root/tests/storage/test_appservice.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_appservice.py')
-rw-r--r--tests/storage/test_appservice.py131
1 files changed, 68 insertions, 63 deletions
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index c893990454..3f0083831b 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -37,18 +37,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.as_yaml_files = []
-        config = Mock(
-            app_service_config_files=self.as_yaml_files,
-            event_cache_size=1,
-            password_providers=[],
-        )
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            config=config,
-            federation_sender=Mock(),
-            federation_client=Mock(),
+            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
         )
 
+        hs.config.app_service_config_files = self.as_yaml_files
+        hs.config.event_cache_size = 1
+        hs.config.password_providers = []
+
         self.as_token = "token1"
         self.as_url = "some_url"
         self.as_id = "as1"
@@ -58,7 +54,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        self.store = ApplicationServiceStore(None, hs)
+        self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
 
     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
@@ -105,18 +101,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     def setUp(self):
         self.as_yaml_files = []
 
-        config = Mock(
-            app_service_config_files=self.as_yaml_files,
-            event_cache_size=1,
-            password_providers=[],
-        )
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            config=config,
-            federation_sender=Mock(),
-            federation_client=Mock(),
+            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
         )
+
+        hs.config.app_service_config_files = self.as_yaml_files
+        hs.config.event_cache_size = 1
+        hs.config.password_providers = []
+
         self.db_pool = hs.get_db_pool()
+        self.engine = hs.database_engine
 
         self.as_list = [
             {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
@@ -129,7 +123,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         self.as_yaml_files = []
 
-        self.store = TestTransactionStore(None, hs)
+        self.store = TestTransactionStore(hs.get_db_conn(), hs)
 
     def _add_service(self, url, as_token, id):
         as_yaml = dict(
@@ -146,29 +140,35 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
             self.as_yaml_files.append(as_token)
 
     def _set_state(self, id, state, txn=None):
-        return self.db_pool.runQuery(
-            "INSERT INTO application_services_state(as_id, state, last_txn) "
-            "VALUES(?,?,?)",
+        return self.db_pool.runOperation(
+            self.engine.convert_param_style(
+                "INSERT INTO application_services_state(as_id, state, last_txn) "
+                "VALUES(?,?,?)"
+            ),
             (id, state, txn),
         )
 
     def _insert_txn(self, as_id, txn_id, events):
-        return self.db_pool.runQuery(
-            "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
-            "VALUES(?,?,?)",
+        return self.db_pool.runOperation(
+            self.engine.convert_param_style(
+                "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
+                "VALUES(?,?,?)"
+            ),
             (as_id, txn_id, json.dumps([e.event_id for e in events])),
         )
 
     def _set_last_txn(self, as_id, txn_id):
-        return self.db_pool.runQuery(
-            "INSERT INTO application_services_state(as_id, last_txn, state) "
-            "VALUES(?,?,?)",
+        return self.db_pool.runOperation(
+            self.engine.convert_param_style(
+                "INSERT INTO application_services_state(as_id, last_txn, state) "
+                "VALUES(?,?,?)"
+            ),
             (as_id, txn_id, ApplicationServiceState.UP),
         )
 
     @defer.inlineCallbacks
     def test_get_appservice_state_none(self):
-        service = Mock(id=999)
+        service = Mock(id="999")
         state = yield self.store.get_appservice_state(service)
         self.assertEquals(None, state)
 
@@ -200,7 +200,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         service = Mock(id=self.as_list[1]["id"])
         yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
         rows = yield self.db_pool.runQuery(
-            "SELECT as_id FROM application_services_state WHERE state=?",
+            self.engine.convert_param_style(
+                "SELECT as_id FROM application_services_state WHERE state=?"
+            ),
             (ApplicationServiceState.DOWN,),
         )
         self.assertEquals(service.id, rows[0][0])
@@ -212,7 +214,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
         yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
         rows = yield self.db_pool.runQuery(
-            "SELECT as_id FROM application_services_state WHERE state=?",
+            self.engine.convert_param_style(
+                "SELECT as_id FROM application_services_state WHERE state=?"
+            ),
             (ApplicationServiceState.UP,),
         )
         self.assertEquals(service.id, rows[0][0])
@@ -279,14 +283,19 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
 
         res = yield self.db_pool.runQuery(
-            "SELECT last_txn FROM application_services_state WHERE as_id=?",
+            self.engine.convert_param_style(
+                "SELECT last_txn FROM application_services_state WHERE as_id=?"
+            ),
             (service.id,),
         )
         self.assertEquals(1, len(res))
         self.assertEquals(txn_id, res[0][0])
 
         res = yield self.db_pool.runQuery(
-            "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
+            self.engine.convert_param_style(
+                "SELECT * FROM application_services_txns WHERE txn_id=?"
+            ),
+            (txn_id,),
         )
         self.assertEquals(0, len(res))
 
@@ -300,7 +309,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
 
         res = yield self.db_pool.runQuery(
-            "SELECT last_txn, state FROM application_services_state WHERE " "as_id=?",
+            self.engine.convert_param_style(
+                "SELECT last_txn, state FROM application_services_state WHERE as_id=?"
+            ),
             (service.id,),
         )
         self.assertEquals(1, len(res))
@@ -308,7 +319,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         self.assertEquals(ApplicationServiceState.UP, res[0][1])
 
         res = yield self.db_pool.runQuery(
-            "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
+            self.engine.convert_param_style(
+                "SELECT * FROM application_services_txns WHERE txn_id=?"
+            ),
+            (txn_id,),
         )
         self.assertEquals(0, len(res))
 
@@ -394,37 +408,31 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         f1 = self._write_config(suffix="1")
         f2 = self._write_config(suffix="2")
 
-        config = Mock(
-            app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
-        )
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            config=config,
-            datastore=Mock(),
-            federation_sender=Mock(),
-            federation_client=Mock(),
+            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
         )
 
-        ApplicationServiceStore(None, hs)
+        hs.config.app_service_config_files = [f1, f2]
+        hs.config.event_cache_size = 1
+        hs.config.password_providers = []
+
+        ApplicationServiceStore(hs.get_db_conn(), hs)
 
     @defer.inlineCallbacks
     def test_duplicate_ids(self):
         f1 = self._write_config(id="id", suffix="1")
         f2 = self._write_config(id="id", suffix="2")
 
-        config = Mock(
-            app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
-        )
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            config=config,
-            datastore=Mock(),
-            federation_sender=Mock(),
-            federation_client=Mock(),
+            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
         )
 
+        hs.config.app_service_config_files = [f1, f2]
+        hs.config.event_cache_size = 1
+        hs.config.password_providers = []
+
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(None, hs)
+            ApplicationServiceStore(hs.get_db_conn(), hs)
 
         e = cm.exception
         self.assertIn(f1, str(e))
@@ -436,19 +444,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         f1 = self._write_config(as_token="as_token", suffix="1")
         f2 = self._write_config(as_token="as_token", suffix="2")
 
-        config = Mock(
-            app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
-        )
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            config=config,
-            datastore=Mock(),
-            federation_sender=Mock(),
-            federation_client=Mock(),
+            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
         )
 
+        hs.config.app_service_config_files = [f1, f2]
+        hs.config.event_cache_size = 1
+        hs.config.password_providers = []
+
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(None, hs)
+            ApplicationServiceStore(hs.get_db_conn(), hs)
 
         e = cm.exception
         self.assertIn(f1, str(e))