summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/appservice.py61
1 files changed, 44 insertions, 17 deletions
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 0b272e82dd..37078f9ef0 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -13,13 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import simplejson
 from simplejson import JSONDecodeError
+import simplejson as json
 from twisted.internet import defer
 
 from synapse.api.constants import Membership
 from synapse.api.errors import StoreError
-from synapse.appservice import ApplicationService
+from synapse.appservice import ApplicationService, ApplicationServiceState
 from synapse.storage.roommember import RoomsForUser
 from ._base import SQLBaseStore
 
@@ -142,7 +142,7 @@ class ApplicationServiceStore(SQLBaseStore):
                     txn.execute(
                         "INSERT INTO application_services_regex("
                         "as_id, namespace, regex) values(?,?,?)",
-                        (as_id, ns_int, simplejson.dumps(regex_obj))
+                        (as_id, ns_int, json.dumps(regex_obj))
                     )
         return True
 
@@ -277,12 +277,7 @@ class ApplicationServiceStore(SQLBaseStore):
 
         return rooms_for_user_matching_user_id
 
-    @defer.inlineCallbacks
-    def _populate_cache(self):
-        """Populates the ApplicationServiceCache from the database."""
-        sql = ("SELECT * FROM application_services LEFT JOIN "
-               "application_services_regex ON application_services.id = "
-               "application_services_regex.as_id")
+    def _parse_services_dict(self, results):
         # SQL results in the form:
         # [
         #   {
@@ -296,13 +291,12 @@ class ApplicationServiceStore(SQLBaseStore):
         #   }
         # ]
         services = {}
-        results = yield self._execute_and_decode(sql)
         for res in results:
             as_token = res["token"]
             if as_token not in services:
                 # add the service
                 services[as_token] = {
-                    "id": res["as_id"],
+                    "id": res["id"],
                     "url": res["url"],
                     "token": as_token,
                     "hs_token": res["hs_token"],
@@ -320,16 +314,16 @@ class ApplicationServiceStore(SQLBaseStore):
             try:
                 services[as_token]["namespaces"][
                     ApplicationService.NS_LIST[ns_int]].append(
-                    simplejson.loads(res["regex"])
+                    json.loads(res["regex"])
                 )
             except IndexError:
                 logger.error("Bad namespace enum '%s'. %s", ns_int, res)
             except JSONDecodeError:
                 logger.error("Bad regex object '%s'", res["regex"])
 
+        service_list = []
         for service in services.values():
-            logger.info("Found application service: %s", service)
-            self.services_cache.append(ApplicationService(
+            service_list.append(ApplicationService(
                 token=service["token"],
                 url=service["url"],
                 namespaces=service["namespaces"],
@@ -337,6 +331,21 @@ class ApplicationServiceStore(SQLBaseStore):
                 sender=service["sender"],
                 id=service["id"]
             ))
+        return service_list
+
+    @defer.inlineCallbacks
+    def _populate_cache(self):
+        """Populates the ApplicationServiceCache from the database."""
+        sql = ("SELECT * FROM application_services LEFT JOIN "
+               "application_services_regex ON application_services.id = "
+               "application_services_regex.as_id")
+
+        results = yield self._execute_and_decode(sql)
+        services = self._parse_services_dict(results)
+
+        for service in services:
+            logger.info("Found application service: %s", service)
+            self.services_cache.append(service)
 
 
 class ApplicationServiceTransactionStore(SQLBaseStore):
@@ -344,6 +353,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
     def __init__(self, hs):
         super(ApplicationServiceTransactionStore, self).__init__(hs)
 
+    @defer.inlineCallbacks
     def get_appservices_by_state(self, state):
         """Get a list of application services based on their state.
 
@@ -353,8 +363,16 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
             A Deferred which resolves to a list of ApplicationServices, which
             may be empty.
         """
-        pass
+        sql = (
+            "SELECT r.*, a.* FROM application_services_state AS s LEFT JOIN "
+            "application_services AS a ON a.id=s.as_id LEFT JOIN "
+            "application_services_regex AS r ON r.as_id=a.id WHERE state = ?"
+        )
+        results = yield self._execute_and_decode(sql, state)
+        # NB: This assumes this class is linked with ApplicationServiceStore
+        defer.returnValue(self._parse_services_dict(results))
 
+    @defer.inlineCallbacks
     def get_appservice_state(self, service):
         """Get the application service state.
 
@@ -363,7 +381,16 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
         Returns:
             A Deferred which resolves to ApplicationServiceState.
         """
-        pass
+        result = yield self._simple_select_one(
+            "application_services_state",
+            dict(as_id=service.id),
+            ["state"],
+            allow_none=True
+        )
+        if result:
+            defer.returnValue(result.get("state"))
+            return
+        defer.returnValue(None)
 
     def set_appservice_state(self, service, state):
         """Set the application service state.
@@ -372,7 +399,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
             service(ApplicationService): The service whose state to set.
             state(ApplicationServiceState): The connectivity state to apply.
         Returns:
-            A Deferred which resolves to True if the state was set successfully.
+            A Deferred which resolves when the state was set successfully.
         """
         return self._simple_upsert(
             "application_services_state",