summary refs log tree commit diff
path: root/synapse/util/task_scheduler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/task_scheduler.py')
-rw-r--r--synapse/util/task_scheduler.py17
1 files changed, 7 insertions, 10 deletions
diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
index 9e89aeb748..9b2581e51a 100644
--- a/synapse/util/task_scheduler.py
+++ b/synapse/util/task_scheduler.py
@@ -77,6 +77,7 @@ class TaskScheduler:
     LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000  # 24hrs
 
     def __init__(self, hs: "HomeServer"):
+        self._hs = hs
         self._store = hs.get_datastores().main
         self._clock = hs.get_clock()
         self._running_tasks: Set[str] = set()
@@ -97,8 +98,6 @@ class TaskScheduler:
                 "handle_scheduled_tasks",
                 self._handle_scheduled_tasks,
             )
-        else:
-            self.replication_client = hs.get_replication_command_handler()
 
     def register_action(
         self,
@@ -133,7 +132,7 @@ class TaskScheduler:
         params: Optional[JsonMapping] = None,
     ) -> str:
         """Schedule a new potentially resumable task. A function matching the specified
-        `action` should have been previously registered with `register_action`.
+        `action` should have be registered with `register_action` before the task is run.
 
         Args:
             action: the name of a previously registered action
@@ -149,11 +148,6 @@ class TaskScheduler:
         Returns:
             The id of the scheduled task
         """
-        if action not in self._actions:
-            raise Exception(
-                f"No function associated with action {action} of the scheduled task"
-            )
-
         status = TaskStatus.SCHEDULED
         if timestamp is None or timestamp < self._clock.time_msec():
             timestamp = self._clock.time_msec()
@@ -175,7 +169,7 @@ class TaskScheduler:
             if self._run_background_tasks:
                 await self._launch_task(task)
             else:
-                self.replication_client.send_new_active_task(task.id)
+                self._hs.get_replication_command_handler().send_new_active_task(task.id)
 
         return task.id
 
@@ -315,7 +309,10 @@ class TaskScheduler:
         """
         assert self._run_background_tasks
 
-        assert task.action in self._actions
+        if task.action not in self._actions:
+            raise Exception(
+                f"No function associated with action {task.action} of the scheduled task {task.id}"
+            )
         function = self._actions[task.action]
 
         async def wrapper() -> None: