summary refs log tree commit diff
path: root/tests/util/test_task_scheduler.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util/test_task_scheduler.py')
-rw-r--r--tests/util/test_task_scheduler.py58
1 files changed, 40 insertions, 18 deletions
diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py
index 3a97559bf0..8665aeb50c 100644
--- a/tests/util/test_task_scheduler.py
+++ b/tests/util/test_task_scheduler.py
@@ -22,10 +22,11 @@ from synapse.types import JsonMapping, ScheduledTask, TaskStatus
 from synapse.util import Clock
 from synapse.util.task_scheduler import TaskScheduler
 
-from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.unittest import HomeserverTestCase, override_config
 
 
-class TestTaskScheduler(unittest.HomeserverTestCase):
+class TestTaskScheduler(HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.task_scheduler = hs.get_task_scheduler()
         self.task_scheduler.register_action(self._test_task, "_test_task")
@@ -34,7 +35,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         self.task_scheduler.register_action(self._resumable_task, "_resumable_task")
 
     async def _test_task(
-        self, task: ScheduledTask, first_launch: bool
+        self, task: ScheduledTask
     ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
         # This test task will copy the parameters to the result
         result = None
@@ -77,7 +78,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         self.assertIsNone(task)
 
     async def _sleeping_task(
-        self, task: ScheduledTask, first_launch: bool
+        self, task: ScheduledTask
     ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
         # Sleep for a second
         await deferLater(self.reactor, 1, lambda: None)
@@ -85,24 +86,18 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
 
     def test_schedule_lot_of_tasks(self) -> None:
         """Schedule more than `TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS` tasks and check the behavior."""
-        timestamp = self.clock.time_msec() + 30 * 1000
         task_ids = []
         for i in range(TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS + 1):
             task_ids.append(
                 self.get_success(
                     self.task_scheduler.schedule_task(
                         "_sleeping_task",
-                        timestamp=timestamp,
                         params={"val": i},
                     )
                 )
             )
 
-        # The timestamp being 30s after now the task should been executed
-        # after the first scheduling loop is run
-        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
-
-        # This is to give the time to the sleeping tasks to finish
+        # This is to give the time to the active tasks to finish
         self.reactor.advance(1)
 
         # Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one
@@ -120,10 +115,11 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         )
 
         scheduled_tasks = [
-            t for t in tasks if t is not None and t.status == TaskStatus.SCHEDULED
+            t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE
         ]
         self.assertEquals(len(scheduled_tasks), 1)
 
+        # We need to wait for the next run of the scheduler loop
         self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
         self.reactor.advance(1)
 
@@ -138,7 +134,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         )
 
     async def _raising_task(
-        self, task: ScheduledTask, first_launch: bool
+        self, task: ScheduledTask
     ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
         raise Exception("raising")
 
@@ -146,15 +142,13 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         """Schedule a task raising an exception and check it runs to failure and report exception content."""
         task_id = self.get_success(self.task_scheduler.schedule_task("_raising_task"))
 
-        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
-
         task = self.get_success(self.task_scheduler.get_task(task_id))
         assert task is not None
         self.assertEqual(task.status, TaskStatus.FAILED)
         self.assertEqual(task.error, "raising")
 
     async def _resumable_task(
-        self, task: ScheduledTask, first_launch: bool
+        self, task: ScheduledTask
     ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
         if task.result and "in_progress" in task.result:
             return TaskStatus.COMPLETE, {"success": True}, None
@@ -169,8 +163,6 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         """Schedule a resumable task and check that it gets properly resumed and complete after simulating a synapse restart."""
         task_id = self.get_success(self.task_scheduler.schedule_task("_resumable_task"))
 
-        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
-
         task = self.get_success(self.task_scheduler.get_task(task_id))
         assert task is not None
         self.assertEqual(task.status, TaskStatus.ACTIVE)
@@ -184,3 +176,33 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         self.assertEqual(task.status, TaskStatus.COMPLETE)
         assert task.result is not None
         self.assertTrue(task.result.get("success"))
+
+
+class TestTaskSchedulerWithBackgroundWorker(BaseMultiWorkerStreamTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.task_scheduler = hs.get_task_scheduler()
+        self.task_scheduler.register_action(self._test_task, "_test_task")
+
+    async def _test_task(
+        self, task: ScheduledTask
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        return (TaskStatus.COMPLETE, None, None)
+
+    @override_config({"run_background_tasks_on": "worker1"})
+    def test_schedule_task(self) -> None:
+        """Check that a task scheduled to run now is launch right away on the background worker."""
+        bg_worker_hs = self.make_worker_hs(
+            "synapse.app.generic_worker",
+            extra_config={"worker_name": "worker1"},
+        )
+        bg_worker_hs.get_task_scheduler().register_action(self._test_task, "_test_task")
+
+        task_id = self.get_success(
+            self.task_scheduler.schedule_task(
+                "_test_task",
+            )
+        )
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        assert task is not None
+        self.assertEqual(task.status, TaskStatus.COMPLETE)