summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/appservice.py61
-rw-r--r--tests/storage/test_appservice.py171
2 files changed, 213 insertions, 19 deletions
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 0b272e82dd..37078f9ef0 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -13,13 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import simplejson
 from simplejson import JSONDecodeError
+import simplejson as json
 from twisted.internet import defer
 
 from synapse.api.constants import Membership
 from synapse.api.errors import StoreError
-from synapse.appservice import ApplicationService
+from synapse.appservice import ApplicationService, ApplicationServiceState
 from synapse.storage.roommember import RoomsForUser
 from ._base import SQLBaseStore
 
@@ -142,7 +142,7 @@ class ApplicationServiceStore(SQLBaseStore):
                     txn.execute(
                         "INSERT INTO application_services_regex("
                         "as_id, namespace, regex) values(?,?,?)",
-                        (as_id, ns_int, simplejson.dumps(regex_obj))
+                        (as_id, ns_int, json.dumps(regex_obj))
                     )
         return True
 
@@ -277,12 +277,7 @@ class ApplicationServiceStore(SQLBaseStore):
 
         return rooms_for_user_matching_user_id
 
-    @defer.inlineCallbacks
-    def _populate_cache(self):
-        """Populates the ApplicationServiceCache from the database."""
-        sql = ("SELECT * FROM application_services LEFT JOIN "
-               "application_services_regex ON application_services.id = "
-               "application_services_regex.as_id")
+    def _parse_services_dict(self, results):
         # SQL results in the form:
         # [
         #   {
@@ -296,13 +291,12 @@ class ApplicationServiceStore(SQLBaseStore):
         #   }
         # ]
         services = {}
-        results = yield self._execute_and_decode(sql)
         for res in results:
             as_token = res["token"]
             if as_token not in services:
                 # add the service
                 services[as_token] = {
-                    "id": res["as_id"],
+                    "id": res["id"],
                     "url": res["url"],
                     "token": as_token,
                     "hs_token": res["hs_token"],
@@ -320,16 +314,16 @@ class ApplicationServiceStore(SQLBaseStore):
             try:
                 services[as_token]["namespaces"][
                     ApplicationService.NS_LIST[ns_int]].append(
-                    simplejson.loads(res["regex"])
+                    json.loads(res["regex"])
                 )
             except IndexError:
                 logger.error("Bad namespace enum '%s'. %s", ns_int, res)
             except JSONDecodeError:
                 logger.error("Bad regex object '%s'", res["regex"])
 
+        service_list = []
         for service in services.values():
-            logger.info("Found application service: %s", service)
-            self.services_cache.append(ApplicationService(
+            service_list.append(ApplicationService(
                 token=service["token"],
                 url=service["url"],
                 namespaces=service["namespaces"],
@@ -337,6 +331,21 @@ class ApplicationServiceStore(SQLBaseStore):
                 sender=service["sender"],
                 id=service["id"]
             ))
+        return service_list
+
+    @defer.inlineCallbacks
+    def _populate_cache(self):
+        """Populates the ApplicationServiceCache from the database."""
+        sql = ("SELECT * FROM application_services LEFT JOIN "
+               "application_services_regex ON application_services.id = "
+               "application_services_regex.as_id")
+
+        results = yield self._execute_and_decode(sql)
+        services = self._parse_services_dict(results)
+
+        for service in services:
+            logger.info("Found application service: %s", service)
+            self.services_cache.append(service)
 
 
 class ApplicationServiceTransactionStore(SQLBaseStore):
@@ -344,6 +353,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
     def __init__(self, hs):
         super(ApplicationServiceTransactionStore, self).__init__(hs)
 
+    @defer.inlineCallbacks
     def get_appservices_by_state(self, state):
         """Get a list of application services based on their state.
 
@@ -353,8 +363,16 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
             A Deferred which resolves to a list of ApplicationServices, which
             may be empty.
         """
-        pass
+        sql = (
+            "SELECT r.*, a.* FROM application_services_state AS s LEFT JOIN "
+            "application_services AS a ON a.id=s.as_id LEFT JOIN "
+            "application_services_regex AS r ON r.as_id=a.id WHERE state = ?"
+        )
+        results = yield self._execute_and_decode(sql, state)
+        # NB: This assumes this class is linked with ApplicationServiceStore
+        defer.returnValue(self._parse_services_dict(results))
 
+    @defer.inlineCallbacks
     def get_appservice_state(self, service):
         """Get the application service state.
 
@@ -363,7 +381,16 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
         Returns:
             A Deferred which resolves to ApplicationServiceState.
         """
-        pass
+        result = yield self._simple_select_one(
+            "application_services_state",
+            dict(as_id=service.id),
+            ["state"],
+            allow_none=True
+        )
+        if result:
+            defer.returnValue(result.get("state"))
+            return
+        defer.returnValue(None)
 
     def set_appservice_state(self, service, state):
         """Set the application service state.
@@ -372,7 +399,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
             service(ApplicationService): The service whose state to set.
             state(ApplicationServiceState): The connectivity state to apply.
         Returns:
-            A Deferred which resolves to True if the state was set successfully.
+            A Deferred which resolves when the state was set successfully.
         """
         return self._simple_upsert(
             "application_services_state",
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ca5b92ec85..30c0b43d96 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -15,9 +15,11 @@
 from tests import unittest
 from twisted.internet import defer
 
-from synapse.appservice import ApplicationService
+from synapse.appservice import ApplicationService, ApplicationServiceState
 from synapse.server import HomeServer
-from synapse.storage.appservice import ApplicationServiceStore
+from synapse.storage.appservice import (
+    ApplicationServiceStore, ApplicationServiceTransactionStore
+)
 
 from mock import Mock
 from tests.utils import SQLiteMemoryDbPool, MockClock
@@ -114,3 +116,168 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
     def test_retrieval_of_all_services(self):
         services = yield self.store.get_app_services()
         self.assertEquals(len(services), 3)
+
+
+class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
+
+    @defer.inlineCallbacks
+    def setUp(self):
+        self.db_pool = SQLiteMemoryDbPool()
+        yield self.db_pool.prepare()
+        hs = HomeServer(
+            "test", db_pool=self.db_pool, clock=MockClock(), config=Mock()
+        )
+        self.as_list = [
+            {
+                "token": "token1",
+                "url": "https://matrix-as.org",
+                "id": 3
+            },
+            {
+                "token": "alpha_tok",
+                "url": "https://alpha.com",
+                "id": 5
+            },
+            {
+                "token": "beta_tok",
+                "url": "https://beta.com",
+                "id": 6
+            },
+            {
+                "token": "delta_tok",
+                "url": "https://delta.com",
+                "id": 7
+            },
+        ]
+        for s in self.as_list:
+            yield self._add_service(s["id"], s["url"], s["token"])
+        self.store = TestTransactionStore(hs)
+
+    def _add_service(self, as_id, url, token):
+        return self.db_pool.runQuery(
+            "INSERT INTO application_services(id, url, token) VALUES(?,?,?)",
+            (as_id, url, token)
+        )
+
+    def _set_state(self, id, state, txn=None):
+        return self.db_pool.runQuery(
+            "INSERT INTO application_services_state(as_id, state, last_txn) "
+            "VALUES(?,?,?)",
+            (id, state, txn)
+        )
+
+    @defer.inlineCallbacks
+    def test_get_appservice_state_none(self):
+        service = Mock(id=999)
+        state = yield self.store.get_appservice_state(service)
+        self.assertEquals(None, state)
+
+    @defer.inlineCallbacks
+    def test_get_appservice_state_up(self):
+        yield self._set_state(
+            self.as_list[0]["id"], ApplicationServiceState.UP
+        )
+        service = Mock(id=self.as_list[0]["id"])
+        state = yield self.store.get_appservice_state(service)
+        self.assertEquals(ApplicationServiceState.UP, state)
+
+    @defer.inlineCallbacks
+    def test_get_appservice_state_down(self):
+        yield self._set_state(
+            self.as_list[0]["id"], ApplicationServiceState.UP
+        )
+        yield self._set_state(
+            self.as_list[1]["id"], ApplicationServiceState.DOWN
+        )
+        yield self._set_state(
+            self.as_list[2]["id"], ApplicationServiceState.DOWN
+        )
+        service = Mock(id=self.as_list[1]["id"])
+        state = yield self.store.get_appservice_state(service)
+        self.assertEquals(ApplicationServiceState.DOWN, state)
+
+    @defer.inlineCallbacks
+    def test_get_appservices_by_state_none(self):
+        services = yield self.store.get_appservices_by_state(
+            ApplicationServiceState.DOWN
+        )
+        self.assertEquals(0, len(services))
+
+    @defer.inlineCallbacks
+    def test_set_appservices_state_down(self):
+        service = Mock(id=self.as_list[1]["id"])
+        yield self.store.set_appservice_state(
+            service,
+            ApplicationServiceState.DOWN
+        )
+        rows = yield self.db_pool.runQuery(
+            "SELECT as_id FROM application_services_state WHERE state=?",
+            (ApplicationServiceState.DOWN,)
+        )
+        self.assertEquals(service.id, rows[0][0])
+
+    @defer.inlineCallbacks
+    def test_set_appservices_state_multiple_up(self):
+        service = Mock(id=self.as_list[1]["id"])
+        yield self.store.set_appservice_state(
+            service,
+            ApplicationServiceState.UP
+        )
+        yield self.store.set_appservice_state(
+            service,
+            ApplicationServiceState.DOWN
+        )
+        yield self.store.set_appservice_state(
+            service,
+            ApplicationServiceState.UP
+        )
+        rows = yield self.db_pool.runQuery(
+            "SELECT as_id FROM application_services_state WHERE state=?",
+            (ApplicationServiceState.UP,)
+        )
+        self.assertEquals(service.id, rows[0][0])
+
+    @defer.inlineCallbacks
+    def test_get_appservices_by_state_single(self):
+        yield self._set_state(
+            self.as_list[0]["id"], ApplicationServiceState.DOWN
+        )
+        yield self._set_state(
+            self.as_list[1]["id"], ApplicationServiceState.UP
+        )
+
+        services = yield self.store.get_appservices_by_state(
+            ApplicationServiceState.DOWN
+        )
+        self.assertEquals(1, len(services))
+        self.assertEquals(self.as_list[0]["id"], services[0].id)
+
+    @defer.inlineCallbacks
+    def test_get_appservices_by_state_multiple(self):
+        yield self._set_state(
+            self.as_list[0]["id"], ApplicationServiceState.DOWN
+        )
+        yield self._set_state(
+            self.as_list[1]["id"], ApplicationServiceState.UP
+        )
+        yield self._set_state(
+            self.as_list[2]["id"], ApplicationServiceState.DOWN
+        )
+        yield self._set_state(
+            self.as_list[3]["id"], ApplicationServiceState.UP
+        )
+
+        services = yield self.store.get_appservices_by_state(
+            ApplicationServiceState.DOWN
+        )
+        self.assertEquals(2, len(services))
+        self.assertEquals(self.as_list[2]["id"], services[0].id)
+        self.assertEquals(self.as_list[0]["id"], services[1].id)
+
+
+# required for ApplicationServiceTransactionStoreTestCase tests
+class TestTransactionStore(ApplicationServiceTransactionStore,
+                           ApplicationServiceStore):
+
+    def __init__(self, hs):
+        super(TestTransactionStore, self).__init__(hs)