diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py
index 3a97559bf0..8665aeb50c 100644
--- a/tests/util/test_task_scheduler.py
+++ b/tests/util/test_task_scheduler.py
@@ -22,10 +22,11 @@ from synapse.types import JsonMapping, ScheduledTask, TaskStatus
from synapse.util import Clock
from synapse.util.task_scheduler import TaskScheduler
-from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.unittest import HomeserverTestCase, override_config
-class TestTaskScheduler(unittest.HomeserverTestCase):
+class TestTaskScheduler(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.task_scheduler = hs.get_task_scheduler()
self.task_scheduler.register_action(self._test_task, "_test_task")
@@ -34,7 +35,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
self.task_scheduler.register_action(self._resumable_task, "_resumable_task")
async def _test_task(
- self, task: ScheduledTask, first_launch: bool
+ self, task: ScheduledTask
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
# This test task will copy the parameters to the result
result = None
@@ -77,7 +78,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
self.assertIsNone(task)
async def _sleeping_task(
- self, task: ScheduledTask, first_launch: bool
+ self, task: ScheduledTask
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
# Sleep for a second
await deferLater(self.reactor, 1, lambda: None)
@@ -85,24 +86,18 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
def test_schedule_lot_of_tasks(self) -> None:
"""Schedule more than `TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS` tasks and check the behavior."""
- timestamp = self.clock.time_msec() + 30 * 1000
task_ids = []
for i in range(TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS + 1):
task_ids.append(
self.get_success(
self.task_scheduler.schedule_task(
"_sleeping_task",
- timestamp=timestamp,
params={"val": i},
)
)
)
- # The timestamp being 30s after now the task should been executed
- # after the first scheduling loop is run
- self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
-
- # This is to give the time to the sleeping tasks to finish
+ # This is to give the time to the active tasks to finish
self.reactor.advance(1)
# Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one
@@ -120,10 +115,11 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
)
scheduled_tasks = [
- t for t in tasks if t is not None and t.status == TaskStatus.SCHEDULED
+ t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE
]
self.assertEquals(len(scheduled_tasks), 1)
+ # We need to wait for the next run of the scheduler loop
self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
self.reactor.advance(1)
@@ -138,7 +134,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
)
async def _raising_task(
- self, task: ScheduledTask, first_launch: bool
+ self, task: ScheduledTask
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
raise Exception("raising")
@@ -146,15 +142,13 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
"""Schedule a task raising an exception and check it runs to failure and report exception content."""
task_id = self.get_success(self.task_scheduler.schedule_task("_raising_task"))
- self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
-
task = self.get_success(self.task_scheduler.get_task(task_id))
assert task is not None
self.assertEqual(task.status, TaskStatus.FAILED)
self.assertEqual(task.error, "raising")
async def _resumable_task(
- self, task: ScheduledTask, first_launch: bool
+ self, task: ScheduledTask
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
if task.result and "in_progress" in task.result:
return TaskStatus.COMPLETE, {"success": True}, None
@@ -169,8 +163,6 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
"""Schedule a resumable task and check that it gets properly resumed and complete after simulating a synapse restart."""
task_id = self.get_success(self.task_scheduler.schedule_task("_resumable_task"))
- self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
-
task = self.get_success(self.task_scheduler.get_task(task_id))
assert task is not None
self.assertEqual(task.status, TaskStatus.ACTIVE)
@@ -184,3 +176,33 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
self.assertEqual(task.status, TaskStatus.COMPLETE)
assert task.result is not None
self.assertTrue(task.result.get("success"))
+
+
+class TestTaskSchedulerWithBackgroundWorker(BaseMultiWorkerStreamTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.task_scheduler = hs.get_task_scheduler()
+ self.task_scheduler.register_action(self._test_task, "_test_task")
+
+ async def _test_task(
+ self, task: ScheduledTask
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ return (TaskStatus.COMPLETE, None, None)
+
+ @override_config({"run_background_tasks_on": "worker1"})
+ def test_schedule_task(self) -> None:
+ """Check that a task scheduled to run now is launch right away on the background worker."""
+ bg_worker_hs = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ extra_config={"worker_name": "worker1"},
+ )
+ bg_worker_hs.get_task_scheduler().register_action(self._test_task, "_test_task")
+
+ task_id = self.get_success(
+ self.task_scheduler.schedule_task(
+ "_test_task",
+ )
+ )
+
+ task = self.get_success(self.task_scheduler.get_task(task_id))
+ assert task is not None
+ self.assertEqual(task.status, TaskStatus.COMPLETE)
|