summary refs log tree commit diff
path: root/synapse/appservice/scheduler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/appservice/scheduler.py')
-rw-r--r--synapse/appservice/scheduler.py74
1 files changed, 44 insertions, 30 deletions
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 6a2ce99b55..185e3a5278 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -48,13 +48,19 @@ This is all tied together by the AppServiceScheduler which DIs the required
 components.
 """
 import logging
-from typing import List, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
 
 from synapse.appservice import ApplicationService, ApplicationServiceState
+from synapse.appservice.api import ApplicationServiceApi
 from synapse.events import EventBase
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main import DataStore
 from synapse.types import JsonDict
+from synapse.util import Clock
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -72,7 +78,7 @@ class ApplicationServiceScheduler:
     case is a simple array.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.as_api = hs.get_application_service_api()
@@ -80,7 +86,7 @@ class ApplicationServiceScheduler:
         self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
         self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
 
-    async def start(self):
+    async def start(self) -> None:
         logger.info("Starting appservice scheduler")
 
         # check for any DOWN ASes and start recoverers for them.
@@ -91,12 +97,14 @@ class ApplicationServiceScheduler:
         for service in services:
             self.txn_ctrl.start_recoverer(service)
 
-    def submit_event_for_as(self, service: ApplicationService, event: EventBase):
+    def submit_event_for_as(
+        self, service: ApplicationService, event: EventBase
+    ) -> None:
         self.queuer.enqueue_event(service, event)
 
     def submit_ephemeral_events_for_as(
         self, service: ApplicationService, events: List[JsonDict]
-    ):
+    ) -> None:
         self.queuer.enqueue_ephemeral(service, events)
 
 
@@ -108,16 +116,18 @@ class _ServiceQueuer:
     appservice at a given time.
     """
 
-    def __init__(self, txn_ctrl, clock):
-        self.queued_events = {}  # dict of {service_id: [events]}
-        self.queued_ephemeral = {}  # dict of {service_id: [events]}
+    def __init__(self, txn_ctrl: "_TransactionController", clock: Clock):
+        # dict of {service_id: [events]}
+        self.queued_events: Dict[str, List[EventBase]] = {}
+        # dict of {service_id: [events]}
+        self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
 
         # the appservices which currently have a transaction in flight
-        self.requests_in_flight = set()
+        self.requests_in_flight: Set[str] = set()
         self.txn_ctrl = txn_ctrl
         self.clock = clock
 
-    def _start_background_request(self, service):
+    def _start_background_request(self, service: ApplicationService) -> None:
         # start a sender for this appservice if we don't already have one
         if service.id in self.requests_in_flight:
             return
@@ -126,15 +136,17 @@ class _ServiceQueuer:
             "as-sender-%s" % (service.id,), self._send_request, service
         )
 
-    def enqueue_event(self, service: ApplicationService, event: EventBase):
+    def enqueue_event(self, service: ApplicationService, event: EventBase) -> None:
         self.queued_events.setdefault(service.id, []).append(event)
         self._start_background_request(service)
 
-    def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]):
+    def enqueue_ephemeral(
+        self, service: ApplicationService, events: List[JsonDict]
+    ) -> None:
         self.queued_ephemeral.setdefault(service.id, []).extend(events)
         self._start_background_request(service)
 
-    async def _send_request(self, service: ApplicationService):
+    async def _send_request(self, service: ApplicationService) -> None:
         # sanity-check: we shouldn't get here if this service already has a sender
         # running.
         assert service.id not in self.requests_in_flight
@@ -168,20 +180,15 @@ class _TransactionController:
     if a transaction fails.
 
     (Note we have only have one of these in the homeserver.)
-
-    Args:
-        clock (synapse.util.Clock):
-        store (synapse.storage.DataStore):
-        as_api (synapse.appservice.api.ApplicationServiceApi):
     """
 
-    def __init__(self, clock, store, as_api):
+    def __init__(self, clock: Clock, store: DataStore, as_api: ApplicationServiceApi):
         self.clock = clock
         self.store = store
         self.as_api = as_api
 
         # map from service id to recoverer instance
-        self.recoverers = {}
+        self.recoverers: Dict[str, "_Recoverer"] = {}
 
         # for UTs
         self.RECOVERER_CLASS = _Recoverer
@@ -191,7 +198,7 @@ class _TransactionController:
         service: ApplicationService,
         events: List[EventBase],
         ephemeral: Optional[List[JsonDict]] = None,
-    ):
+    ) -> None:
         try:
             txn = await self.store.create_appservice_txn(
                 service=service, events=events, ephemeral=ephemeral or []
@@ -207,7 +214,7 @@ class _TransactionController:
             logger.exception("Error creating appservice transaction")
             run_in_background(self._on_txn_fail, service)
 
-    async def on_recovered(self, recoverer):
+    async def on_recovered(self, recoverer: "_Recoverer") -> None:
         logger.info(
             "Successfully recovered application service AS ID %s", recoverer.service.id
         )
@@ -217,18 +224,18 @@ class _TransactionController:
             recoverer.service, ApplicationServiceState.UP
         )
 
-    async def _on_txn_fail(self, service):
+    async def _on_txn_fail(self, service: ApplicationService) -> None:
         try:
             await self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
             self.start_recoverer(service)
         except Exception:
             logger.exception("Error starting AS recoverer")
 
-    def start_recoverer(self, service):
+    def start_recoverer(self, service: ApplicationService) -> None:
         """Start a Recoverer for the given service
 
         Args:
-            service (synapse.appservice.ApplicationService):
+            service:
         """
         logger.info("Starting recoverer for AS ID %s", service.id)
         assert service.id not in self.recoverers
@@ -257,7 +264,14 @@ class _Recoverer:
         callback (callable[_Recoverer]): called once the service recovers.
     """
 
-    def __init__(self, clock, store, as_api, service, callback):
+    def __init__(
+        self,
+        clock: Clock,
+        store: DataStore,
+        as_api: ApplicationServiceApi,
+        service: ApplicationService,
+        callback: Callable[["_Recoverer"], Awaitable[None]],
+    ):
         self.clock = clock
         self.store = store
         self.as_api = as_api
@@ -265,8 +279,8 @@ class _Recoverer:
         self.callback = callback
         self.backoff_counter = 1
 
-    def recover(self):
-        def _retry():
+    def recover(self) -> None:
+        def _retry() -> None:
             run_as_background_process(
                 "as-recoverer-%s" % (self.service.id,), self.retry
             )
@@ -275,13 +289,13 @@ class _Recoverer:
         logger.info("Scheduling retries on %s in %fs", self.service.id, delay)
         self.clock.call_later(delay, _retry)
 
-    def _backoff(self):
+    def _backoff(self) -> None:
         # cap the backoff to be around 8.5min => (2^9) = 512 secs
         if self.backoff_counter < 9:
             self.backoff_counter += 1
         self.recover()
 
-    async def retry(self):
+    async def retry(self) -> None:
         logger.info("Starting retries on %s", self.service.id)
         try:
             while True: