summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/appservice/__init__.py5
-rw-r--r--synapse/appservice/scheduler.py37
-rw-r--r--synapse/storage/appservice.py17
-rw-r--r--tests/appservice/test_scheduler.py5
4 files changed, 52 insertions, 12 deletions
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index a268a6bcc4..cc6c381566 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -20,6 +20,11 @@ import re
 logger = logging.getLogger(__name__)
 
 
+class ApplicationServiceState(object):
+    DOWN = "down"
+    UP = "up"
+
+
 class ApplicationService(object):
     """Defines an application service. This definition is mostly what is
     provided to the /register AS API.
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 645d7bf6b2..99e83747a8 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -49,7 +49,11 @@ This is all tied together by the AppServiceScheduler which DIs the required
 components.
 """
 
+from synapse.appservice import ApplicationServiceState
 from twisted.internet import defer
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 class AppServiceScheduler(object):
@@ -162,21 +166,36 @@ class _TransactionController(object):
                 if txn.send(self.as_api):
                     txn.complete(self.store)
                 else:
-                    # TODO mark AS as down
                     self._start_recoverer(service)
         self.clock.call_later(1000, self.start_polling)
 
-    def on_recovered(self, service):
-        # TODO mark AS as UP
-        pass
+    @defer.inlineCallbacks
+    def on_recovered(self, recoverer):
+        applied_state = yield self.store.set_appservice_state(
+            recoverer.service,
+            ApplicationServiceState.UP
+        )
+        if not applied_state:
+            logger.error("Failed to apply appservice state UP to service %s",
+                         recoverer.service)
 
     def add_recoverers(self, recoverers):
         for r in recoverers:
             self.recoverers.append(r)
 
+    @defer.inlineCallbacks
     def _start_recoverer(self, service):
-        recoverer = self.recoverer_fn(service, self.on_recovered)
-        recoverer.recover()
+        applied_state = yield self.store.set_appservice_state(
+            service,
+            ApplicationServiceState.DOWN
+        )
+        if applied_state:
+            recoverer = self.recoverer_fn(service, self.on_recovered)
+            self.add_recoverers([recoverer])
+            recoverer.recover()
+        else:
+            logger.error("Failed to apply appservice state DOWN to service %s",
+                         service)
 
     def _is_service_up(self, service):
         pass
@@ -193,7 +212,9 @@ class _Recoverer(object):
     @staticmethod
     @defer.inlineCallbacks
     def start(clock, store, as_api, callback):
-        services = yield store.get_failing_appservices()
+        services = yield store.get_appservices_by_state(
+            ApplicationServiceState.DOWN
+        )
         recoverers = [
             _Recoverer(clock, store, as_api, s, callback) for s in services
         ]
@@ -228,7 +249,7 @@ class _Recoverer(object):
             self._set_service_recovered()
 
     def _set_service_recovered(self):
-        self.callback(self.service)
+        self.callback(self)
 
     @defer.inlineCallbacks
     def _get_oldest_txn(self):
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index c1762692b9..214f6d99c5 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -343,15 +343,28 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
     def __init__(self, hs):
         super(ApplicationServiceTransactionStore, self).__init__(hs)
 
-    def get_failing_appservices(self):
-        """Get a list of application services which are down.
+    def get_appservices_by_state(self, state):
+        """Get a list of application services based on their state.
 
+        Args:
+            state(ApplicationServiceState): The state to filter on.
         Returns:
             A Deferred which resolves to a list of ApplicationServices, which
             may be empty.
         """
         pass
 
+    def set_appservice_state(self, service, state):
+        """Set the application service state.
+
+        Args:
+            service(ApplicationService): The service whose state to set.
+            state(ApplicationServiceState): The connectivity state to apply.
+        Returns:
+            A Deferred which resolves to True if the state was set successfully.
+        """
+        pass
+
     def complete_appservice_txn(self, txn_id, service):
         """Completes an application service transaction.
 
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 1e3eb9e1cc..ec8f77c54b 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -57,7 +57,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.assertEquals(1, txn.complete.call_count)
         # 2 because it needs to get None to know there are no more txns
         self.assertEquals(2, self.store.get_oldest_txn.call_count)
-        self.assertEquals(1, self.callback.call_count)
+        self.callback.assert_called_once_with(self.recoverer)
+        self.assertEquals(self.recoverer.service, self.service)
 
     def test_recover_retry_txn(self):
         txn = Mock()
@@ -91,7 +92,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.clock.advance_time(16000)
         self.assertEquals(1, txn.send.call_count)  # new mock reset call count
         self.assertEquals(1, txn.complete.call_count)
-        self.assertEquals(1, self.callback.call_count)
+        self.callback.assert_called_once_with(self.recoverer)
 
 class ApplicationServiceSchedulerEventGrouperTestCase(unittest.TestCase):