summary refs log tree commit diff
path: root/tests/storage/test_appservice.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_appservice.py')
-rw-r--r--tests/storage/test_appservice.py86
1 files changed, 52 insertions, 34 deletions
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 622b16a071..ef296e7dab 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -24,10 +24,11 @@ from twisted.internet import defer
 
 from synapse.appservice import ApplicationService, ApplicationServiceState
 from synapse.config._base import ConfigError
-from synapse.storage.appservice import (
+from synapse.storage.data_stores.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
+from synapse.storage.database import Database, make_conn
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
@@ -42,7 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         )
 
         hs.config.app_service_config_files = self.as_yaml_files
-        hs.config.event_cache_size = 1
+        hs.config.caches.event_cache_size = 1
         hs.config.password_providers = []
 
         self.as_token = "token1"
@@ -54,7 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
+        database = hs.get_datastores().databases[0]
+        self.store = ApplicationServiceStore(
+            database, make_conn(database._database_config, database.engine), hs
+        )
 
     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
@@ -65,14 +69,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
                 pass
 
     def _add_appservice(self, as_token, id, url, hs_token, sender):
-        as_yaml = dict(
-            url=url,
-            as_token=as_token,
-            hs_token=hs_token,
-            id=id,
-            sender_localpart=sender,
-            namespaces={},
-        )
+        as_yaml = {
+            "url": url,
+            "as_token": as_token,
+            "hs_token": hs_token,
+            "id": id,
+            "sender_localpart": sender,
+            "namespaces": {},
+        }
         # use the token as the filename
         with open(as_token, "w") as outfile:
             outfile.write(yaml.dump(as_yaml))
@@ -106,12 +110,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         )
 
         hs.config.app_service_config_files = self.as_yaml_files
-        hs.config.event_cache_size = 1
+        hs.config.caches.event_cache_size = 1
         hs.config.password_providers = []
 
-        self.db_pool = hs.get_db_pool()
-        self.engine = hs.database_engine
-
         self.as_list = [
             {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
             {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},
@@ -123,17 +124,25 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         self.as_yaml_files = []
 
-        self.store = TestTransactionStore(hs.get_db_conn(), hs)
+        # We assume there is only one database in these tests
+        database = hs.get_datastores().databases[0]
+        self.db_pool = database._db_pool
+        self.engine = database.engine
 
-    def _add_service(self, url, as_token, id):
-        as_yaml = dict(
-            url=url,
-            as_token=as_token,
-            hs_token="something",
-            id=id,
-            sender_localpart="a_sender",
-            namespaces={},
+        db_config = hs.config.get_single_database()
+        self.store = TestTransactionStore(
+            database, make_conn(db_config, self.engine), hs
         )
+
+    def _add_service(self, url, as_token, id):
+        as_yaml = {
+            "url": url,
+            "as_token": as_token,
+            "hs_token": "something",
+            "id": id,
+            "sender_localpart": "a_sender",
+            "namespaces": {},
+        }
         # use the token as the filename
         with open(as_token, "w") as outfile:
             outfile.write(yaml.dump(as_yaml))
@@ -375,15 +384,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         )
         self.assertEquals(2, len(services))
         self.assertEquals(
-            set([self.as_list[2]["id"], self.as_list[0]["id"]]),
-            set([services[0].id, services[1].id]),
+            {self.as_list[2]["id"], self.as_list[0]["id"]},
+            {services[0].id, services[1].id},
         )
 
 
 # required for ApplicationServiceTransactionStoreTestCase tests
 class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
-    def __init__(self, db_conn, hs):
-        super(TestTransactionStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(TestTransactionStore, self).__init__(database, db_conn, hs)
 
 
 class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
@@ -413,10 +422,13 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         )
 
         hs.config.app_service_config_files = [f1, f2]
-        hs.config.event_cache_size = 1
+        hs.config.caches.event_cache_size = 1
         hs.config.password_providers = []
 
-        ApplicationServiceStore(hs.get_db_conn(), hs)
+        database = hs.get_datastores().databases[0]
+        ApplicationServiceStore(
+            database, make_conn(database._database_config, database.engine), hs
+        )
 
     @defer.inlineCallbacks
     def test_duplicate_ids(self):
@@ -428,11 +440,14 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         )
 
         hs.config.app_service_config_files = [f1, f2]
-        hs.config.event_cache_size = 1
+        hs.config.caches.event_cache_size = 1
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs.get_db_conn(), hs)
+            database = hs.get_datastores().databases[0]
+            ApplicationServiceStore(
+                database, make_conn(database._database_config, database.engine), hs
+            )
 
         e = cm.exception
         self.assertIn(f1, str(e))
@@ -449,11 +464,14 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         )
 
         hs.config.app_service_config_files = [f1, f2]
-        hs.config.event_cache_size = 1
+        hs.config.caches.event_cache_size = 1
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs.get_db_conn(), hs)
+            database = hs.get_datastores().databases[0]
+            ApplicationServiceStore(
+                database, make_conn(database._database_config, database.engine), hs
+            )
 
         e = cm.exception
         self.assertIn(f1, str(e))