diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 0b272e82dd..37078f9ef0 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import simplejson
from simplejson import JSONDecodeError
+import simplejson as json
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import StoreError
-from synapse.appservice import ApplicationService
+from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.storage.roommember import RoomsForUser
from ._base import SQLBaseStore
@@ -142,7 +142,7 @@ class ApplicationServiceStore(SQLBaseStore):
txn.execute(
"INSERT INTO application_services_regex("
"as_id, namespace, regex) values(?,?,?)",
- (as_id, ns_int, simplejson.dumps(regex_obj))
+ (as_id, ns_int, json.dumps(regex_obj))
)
return True
@@ -277,12 +277,7 @@ class ApplicationServiceStore(SQLBaseStore):
return rooms_for_user_matching_user_id
- @defer.inlineCallbacks
- def _populate_cache(self):
- """Populates the ApplicationServiceCache from the database."""
- sql = ("SELECT * FROM application_services LEFT JOIN "
- "application_services_regex ON application_services.id = "
- "application_services_regex.as_id")
+ def _parse_services_dict(self, results):
# SQL results in the form:
# [
# {
@@ -296,13 +291,12 @@ class ApplicationServiceStore(SQLBaseStore):
# }
# ]
services = {}
- results = yield self._execute_and_decode(sql)
for res in results:
as_token = res["token"]
if as_token not in services:
# add the service
services[as_token] = {
- "id": res["as_id"],
+ "id": res["id"],
"url": res["url"],
"token": as_token,
"hs_token": res["hs_token"],
@@ -320,16 +314,16 @@ class ApplicationServiceStore(SQLBaseStore):
try:
services[as_token]["namespaces"][
ApplicationService.NS_LIST[ns_int]].append(
- simplejson.loads(res["regex"])
+ json.loads(res["regex"])
)
except IndexError:
logger.error("Bad namespace enum '%s'. %s", ns_int, res)
except JSONDecodeError:
logger.error("Bad regex object '%s'", res["regex"])
+ service_list = []
for service in services.values():
- logger.info("Found application service: %s", service)
- self.services_cache.append(ApplicationService(
+ service_list.append(ApplicationService(
token=service["token"],
url=service["url"],
namespaces=service["namespaces"],
@@ -337,6 +331,21 @@ class ApplicationServiceStore(SQLBaseStore):
sender=service["sender"],
id=service["id"]
))
+ return service_list
+
+ @defer.inlineCallbacks
+ def _populate_cache(self):
+ """Populates the ApplicationServiceCache from the database."""
+ sql = ("SELECT * FROM application_services LEFT JOIN "
+ "application_services_regex ON application_services.id = "
+ "application_services_regex.as_id")
+
+ results = yield self._execute_and_decode(sql)
+ services = self._parse_services_dict(results)
+
+ for service in services:
+ logger.info("Found application service: %s", service)
+ self.services_cache.append(service)
class ApplicationServiceTransactionStore(SQLBaseStore):
@@ -344,6 +353,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
def __init__(self, hs):
super(ApplicationServiceTransactionStore, self).__init__(hs)
+ @defer.inlineCallbacks
def get_appservices_by_state(self, state):
"""Get a list of application services based on their state.
@@ -353,8 +363,16 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
A Deferred which resolves to a list of ApplicationServices, which
may be empty.
"""
- pass
+ sql = (
+ "SELECT r.*, a.* FROM application_services_state AS s LEFT JOIN "
+ "application_services AS a ON a.id=s.as_id LEFT JOIN "
+ "application_services_regex AS r ON r.as_id=a.id WHERE state = ?"
+ )
+ results = yield self._execute_and_decode(sql, state)
+ # NB: This assumes this class is linked with ApplicationServiceStore
+ defer.returnValue(self._parse_services_dict(results))
+ @defer.inlineCallbacks
def get_appservice_state(self, service):
"""Get the application service state.
@@ -363,7 +381,16 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
Returns:
A Deferred which resolves to ApplicationServiceState.
"""
- pass
+ result = yield self._simple_select_one(
+ "application_services_state",
+ dict(as_id=service.id),
+ ["state"],
+ allow_none=True
+ )
+ if result:
+ defer.returnValue(result.get("state"))
+ return
+ defer.returnValue(None)
def set_appservice_state(self, service, state):
"""Set the application service state.
@@ -372,7 +399,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
Returns:
- A Deferred which resolves to True if the state was set successfully.
+ A Deferred which resolves when the state was set successfully.
"""
return self._simple_upsert(
"application_services_state",
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ca5b92ec85..30c0b43d96 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -15,9 +15,11 @@
from tests import unittest
from twisted.internet import defer
-from synapse.appservice import ApplicationService
+from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.server import HomeServer
-from synapse.storage.appservice import ApplicationServiceStore
+from synapse.storage.appservice import (
+ ApplicationServiceStore, ApplicationServiceTransactionStore
+)
from mock import Mock
from tests.utils import SQLiteMemoryDbPool, MockClock
@@ -114,3 +116,168 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
def test_retrieval_of_all_services(self):
services = yield self.store.get_app_services()
self.assertEquals(len(services), 3)
+
+
+class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.db_pool = SQLiteMemoryDbPool()
+ yield self.db_pool.prepare()
+ hs = HomeServer(
+ "test", db_pool=self.db_pool, clock=MockClock(), config=Mock()
+ )
+ self.as_list = [
+ {
+ "token": "token1",
+ "url": "https://matrix-as.org",
+ "id": 3
+ },
+ {
+ "token": "alpha_tok",
+ "url": "https://alpha.com",
+ "id": 5
+ },
+ {
+ "token": "beta_tok",
+ "url": "https://beta.com",
+ "id": 6
+ },
+ {
+ "token": "delta_tok",
+ "url": "https://delta.com",
+ "id": 7
+ },
+ ]
+ for s in self.as_list:
+ yield self._add_service(s["id"], s["url"], s["token"])
+ self.store = TestTransactionStore(hs)
+
+ def _add_service(self, as_id, url, token):
+ return self.db_pool.runQuery(
+ "INSERT INTO application_services(id, url, token) VALUES(?,?,?)",
+ (as_id, url, token)
+ )
+
+ def _set_state(self, id, state, txn=None):
+ return self.db_pool.runQuery(
+ "INSERT INTO application_services_state(as_id, state, last_txn) "
+ "VALUES(?,?,?)",
+ (id, state, txn)
+ )
+
+ @defer.inlineCallbacks
+ def test_get_appservice_state_none(self):
+ service = Mock(id=999)
+ state = yield self.store.get_appservice_state(service)
+ self.assertEquals(None, state)
+
+ @defer.inlineCallbacks
+ def test_get_appservice_state_up(self):
+ yield self._set_state(
+ self.as_list[0]["id"], ApplicationServiceState.UP
+ )
+ service = Mock(id=self.as_list[0]["id"])
+ state = yield self.store.get_appservice_state(service)
+ self.assertEquals(ApplicationServiceState.UP, state)
+
+ @defer.inlineCallbacks
+ def test_get_appservice_state_down(self):
+ yield self._set_state(
+ self.as_list[0]["id"], ApplicationServiceState.UP
+ )
+ yield self._set_state(
+ self.as_list[1]["id"], ApplicationServiceState.DOWN
+ )
+ yield self._set_state(
+ self.as_list[2]["id"], ApplicationServiceState.DOWN
+ )
+ service = Mock(id=self.as_list[1]["id"])
+ state = yield self.store.get_appservice_state(service)
+ self.assertEquals(ApplicationServiceState.DOWN, state)
+
+ @defer.inlineCallbacks
+ def test_get_appservices_by_state_none(self):
+ services = yield self.store.get_appservices_by_state(
+ ApplicationServiceState.DOWN
+ )
+ self.assertEquals(0, len(services))
+
+ @defer.inlineCallbacks
+ def test_set_appservices_state_down(self):
+ service = Mock(id=self.as_list[1]["id"])
+ yield self.store.set_appservice_state(
+ service,
+ ApplicationServiceState.DOWN
+ )
+ rows = yield self.db_pool.runQuery(
+ "SELECT as_id FROM application_services_state WHERE state=?",
+ (ApplicationServiceState.DOWN,)
+ )
+ self.assertEquals(service.id, rows[0][0])
+
+ @defer.inlineCallbacks
+ def test_set_appservices_state_multiple_up(self):
+ service = Mock(id=self.as_list[1]["id"])
+ yield self.store.set_appservice_state(
+ service,
+ ApplicationServiceState.UP
+ )
+ yield self.store.set_appservice_state(
+ service,
+ ApplicationServiceState.DOWN
+ )
+ yield self.store.set_appservice_state(
+ service,
+ ApplicationServiceState.UP
+ )
+ rows = yield self.db_pool.runQuery(
+ "SELECT as_id FROM application_services_state WHERE state=?",
+ (ApplicationServiceState.UP,)
+ )
+ self.assertEquals(service.id, rows[0][0])
+
+ @defer.inlineCallbacks
+ def test_get_appservices_by_state_single(self):
+ yield self._set_state(
+ self.as_list[0]["id"], ApplicationServiceState.DOWN
+ )
+ yield self._set_state(
+ self.as_list[1]["id"], ApplicationServiceState.UP
+ )
+
+ services = yield self.store.get_appservices_by_state(
+ ApplicationServiceState.DOWN
+ )
+ self.assertEquals(1, len(services))
+ self.assertEquals(self.as_list[0]["id"], services[0].id)
+
+ @defer.inlineCallbacks
+ def test_get_appservices_by_state_multiple(self):
+ yield self._set_state(
+ self.as_list[0]["id"], ApplicationServiceState.DOWN
+ )
+ yield self._set_state(
+ self.as_list[1]["id"], ApplicationServiceState.UP
+ )
+ yield self._set_state(
+ self.as_list[2]["id"], ApplicationServiceState.DOWN
+ )
+ yield self._set_state(
+ self.as_list[3]["id"], ApplicationServiceState.UP
+ )
+
+ services = yield self.store.get_appservices_by_state(
+ ApplicationServiceState.DOWN
+ )
+ self.assertEquals(2, len(services))
+ self.assertEquals(self.as_list[2]["id"], services[0].id)
+ self.assertEquals(self.as_list[0]["id"], services[1].id)
+
+
+# required for ApplicationServiceTransactionStoreTestCase tests
+class TestTransactionStore(ApplicationServiceTransactionStore,
+ ApplicationServiceStore):
+
+ def __init__(self, hs):
+ super(TestTransactionStore, self).__init__(hs)
|