diff --git a/changelog.d/15891.feature b/changelog.d/15891.feature
new file mode 100644
index 0000000000..5024b5adc4
--- /dev/null
+++ b/changelog.d/15891.feature
@@ -0,0 +1 @@
+Implements a task scheduler for resumable potentially long running tasks.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index dc79efcc14..d25e3548e0 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -91,6 +91,7 @@ from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.storage.databases.main.tags import TagsWorkerStore
+from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore
from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
from synapse.storage.databases.main.user_directory import UserDirectoryStore
@@ -144,6 +145,7 @@ class GenericWorkerStore(
TransactionWorkerStore,
LockStore,
SessionStore,
+ TaskSchedulerWorkerStore,
):
# Properties that multiple storage classes define. Tell mypy what the
# expected type is.
diff --git a/synapse/server.py b/synapse/server.py
index b72b76a38b..b1c7fedaca 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -141,6 +141,7 @@ from synapse.util.distributor import Distributor
from synapse.util.macaroons import MacaroonGenerator
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import random_string
+from synapse.util.task_scheduler import TaskScheduler
logger = logging.getLogger(__name__)
@@ -359,6 +360,7 @@ class HomeServer(metaclass=abc.ABCMeta):
"""
for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
getattr(self, "get_" + i + "_handler")()
+ self.get_task_scheduler()
def get_reactor(self) -> ISynapseReactor:
"""
@@ -912,3 +914,7 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_common_usage_metrics_manager(self) -> CommonUsageMetricsManager:
"""Usage metrics shared between phone home stats and the prometheus exporter."""
return CommonUsageMetricsManager(self)
+
+ @cache_in_self
+ def get_task_scheduler(self) -> TaskScheduler:
+ return TaskScheduler(self)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index b6028853c9..cb8fb665e4 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -70,6 +70,7 @@ from .state import StateStore
from .stats import StatsStore
from .stream import StreamWorkerStore
from .tags import TagsStore
+from .task_scheduler import TaskSchedulerWorkerStore
from .transactions import TransactionWorkerStore
from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
@@ -127,6 +128,7 @@ class DataStore(
CacheInvalidationWorkerStore,
LockStore,
SessionStore,
+ TaskSchedulerWorkerStore,
):
def __init__(
self,
diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py
new file mode 100644
index 0000000000..37c2110bbd
--- /dev/null
+++ b/synapse/storage/databases/main/task_scheduler.py
@@ -0,0 +1,173 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.types import JsonDict, JsonMapping, ScheduledTask, TaskStatus
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+class TaskSchedulerWorkerStore(SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ @staticmethod
+ def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask:
+ row["status"] = TaskStatus(row["status"])
+ if row["params"] is not None:
+ row["params"] = json.loads(row["params"])
+ if row["result"] is not None:
+ row["result"] = json.loads(row["result"])
+ return ScheduledTask(**row)
+
+ async def get_scheduled_tasks(
+ self, action: Optional[str] = None, resource_id: Optional[str] = None
+ ) -> List[ScheduledTask]:
+ """Get a list of scheduled tasks from the DB.
+
+ If the parameters are `None` all the tasks are returned.
+
+ Args:
+ action: Limit the returned tasks to this specific action name
+ resource_id: Limit the returned tasks to this specific resource id
+
+ Returns: a list of `ScheduledTask`
+ """
+ keyvalues = {}
+ if action:
+ keyvalues["action"] = action
+ if resource_id:
+ keyvalues["resource_id"] = resource_id
+
+ rows = await self.db_pool.simple_select_list(
+ table="scheduled_tasks",
+ keyvalues=keyvalues,
+ retcols=(
+ "id",
+ "action",
+ "status",
+ "timestamp",
+ "resource_id",
+ "params",
+ "result",
+ "error",
+ ),
+ desc="get_scheduled_tasks",
+ )
+
+ return [TaskSchedulerWorkerStore._convert_row_to_task(row) for row in rows]
+
+ async def upsert_scheduled_task(self, task: ScheduledTask) -> None:
+ """Upsert a specified `ScheduledTask` in the DB.
+
+ Args:
+ task: the `ScheduledTask` to upsert
+ """
+ await self.db_pool.simple_upsert(
+ "scheduled_tasks",
+ {"id": task.id},
+ {
+ "action": task.action,
+ "status": task.status,
+ "timestamp": task.timestamp,
+ "resource_id": task.resource_id,
+ "params": None if task.params is None else json.dumps(task.params),
+ "result": None if task.result is None else json.dumps(task.result),
+ "error": task.error,
+ },
+ desc="upsert_scheduled_task",
+ )
+
+ async def update_scheduled_task(
+ self,
+ id: str,
+ *,
+ timestamp: Optional[int] = None,
+ status: Optional[TaskStatus] = None,
+ result: Optional[JsonMapping] = None,
+ error: Optional[str] = None,
+ ) -> None:
+ """Update a scheduled task in the DB with some new value(s).
+
+ Args:
+ id: id of the `ScheduledTask` to update
+ timestamp: new timestamp of the task
+ status: new status of the task
+ result: new result of the task
+ error: new error of the task
+ """
+ updatevalues: JsonDict = {}
+ if timestamp is not None:
+ updatevalues["timestamp"] = timestamp
+ if status is not None:
+ updatevalues["status"] = status
+ if result is not None:
+ updatevalues["result"] = json.dumps(result)
+ if error is not None:
+ updatevalues["error"] = error
+ await self.db_pool.simple_update(
+ "scheduled_tasks",
+ {"id": id},
+ updatevalues,
+ desc="update_scheduled_task",
+ )
+
+ async def get_scheduled_task(self, id: str) -> Optional[ScheduledTask]:
+ """Get a specific `ScheduledTask` from its id.
+
+ Args:
+ id: the id of the task to retrieve
+
+ Returns: the task if available, `None` otherwise
+ """
+ row = await self.db_pool.simple_select_one(
+ table="scheduled_tasks",
+ keyvalues={"id": id},
+ retcols=(
+ "id",
+ "action",
+ "status",
+ "timestamp",
+ "resource_id",
+ "params",
+ "result",
+ "error",
+ ),
+ allow_none=True,
+ desc="get_scheduled_task",
+ )
+
+ return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
+
+ async def delete_scheduled_task(self, id: str) -> None:
+ """Delete a specific task from its id.
+
+ Args:
+ id: the id of the task to delete
+ """
+ await self.db_pool.simple_delete(
+ "scheduled_tasks",
+ keyvalues={"id": id},
+ desc="delete_scheduled_task",
+ )
diff --git a/synapse/storage/schema/main/delta/79/03_scheduled_tasks.sql b/synapse/storage/schema/main/delta/79/03_scheduled_tasks.sql
new file mode 100644
index 0000000000..4ee43887b6
--- /dev/null
+++ b/synapse/storage/schema/main/delta/79/03_scheduled_tasks.sql
@@ -0,0 +1,26 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- cf ScheduledTask docstring for the meaning of the fields.
+CREATE TABLE IF NOT EXISTS scheduled_tasks(
+ id text PRIMARY KEY,
+ action text NOT NULL,
+ status text NOT NULL,
+ timestamp bigint NOT NULL,
+ resource_id text,
+ params text,
+ result text,
+ error text
+);
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 095be070e0..7effac8c1d 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -15,6 +15,7 @@
import abc
import re
import string
+from enum import Enum
from typing import (
TYPE_CHECKING,
AbstractSet,
@@ -979,3 +980,40 @@ class UserProfile(TypedDict):
class RetentionPolicy:
min_lifetime: Optional[int] = None
max_lifetime: Optional[int] = None
+
+
+class TaskStatus(str, Enum):
+ """Status of a scheduled task"""
+
+ # Task is scheduled but not active
+ SCHEDULED = "scheduled"
+ # Task is active and probably running, and if not
+ # will be run on next scheduler loop run
+ ACTIVE = "active"
+ # Task has completed successfully
+ COMPLETE = "complete"
+ # Task is over and either returned a failed status, or had an exception
+ FAILED = "failed"
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class ScheduledTask:
+ """Description of a scheduled task"""
+
+ # id used to identify the task
+ id: str
+ # name of the action to be run by this task
+ action: str
+ # current status of this task
+ status: TaskStatus
+ # if the status is SCHEDULED then this represents when it should be launched,
+ # otherwise it represents the last time this task got a change of state
+ timestamp: int
+ # Optionally bind a task to some resource id for easy retrieval
+ resource_id: Optional[str]
+ # Optional parameters that will be passed to the function ran by the task
+ params: Optional[JsonMapping]
+ # Optional result that can be updated by the running task
+ result: Optional[JsonMapping]
+ # Optional error that should be assigned a value when the status is FAILED
+ error: Optional[str]
diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
new file mode 100644
index 0000000000..2c3abaaedf
--- /dev/null
+++ b/synapse/util/task_scheduler.py
@@ -0,0 +1,244 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set, Tuple
+
+from twisted.python.failure import Failure
+
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonMapping, ScheduledTask, TaskStatus
+from synapse.util.stringutils import random_string
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class TaskScheduler:
+ # Precision of the scheduler, evaluation of tasks to run will only happen
+ # every `SCHEDULE_INTERVAL_MS` ms
+ SCHEDULE_INTERVAL_MS = 5 * 60 * 1000 # 5mn
+ # Time before a complete or failed task is deleted from the DB
+ KEEP_TASKS_FOR_MS = 7 * 24 * 60 * 60 * 1000 # 1 week
+
+ def __init__(self, hs: "HomeServer"):
+ self.store = hs.get_datastores().main
+ self.clock = hs.get_clock()
+ self.running_tasks: Set[str] = set()
+ self.actions: Dict[
+ str,
+ Callable[
+ [ScheduledTask, bool],
+ Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]],
+ ],
+ ] = {}
+ self.run_background_tasks = hs.config.worker.run_background_tasks
+
+ if self.run_background_tasks:
+ self.clock.looping_call(
+ run_as_background_process,
+ TaskScheduler.SCHEDULE_INTERVAL_MS,
+ "scheduled_tasks_loop",
+ self._scheduled_tasks_loop,
+ )
+
+ def register_action(
+ self,
+ function: Callable[
+ [ScheduledTask, bool],
+ Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]],
+ ],
+ action_name: str,
+ ) -> None:
+ """Register a function to be executed when an action is scheduled with
+ the specified action name.
+
+ Actions need to be registered as early as possible so that a resumed action
+ can find its matching function. It's usually better to NOT do that right before
+ calling `schedule_task` but rather in an `__init__` method.
+
+ Args:
+ function: The function to be executed for this action. The parameters
+ passed to the function when launched are the `ScheduledTask` being run,
+ and a `first_launch` boolean to signal if it's a resumed task or the first
+ launch of it. The function should return a tuple of new `status`, `result`
+ and `error` as specified in `ScheduledTask`.
+ action_name: The name of the action to be associated with the function
+ """
+ self.actions[action_name] = function
+
+ async def schedule_task(
+ self,
+ action: str,
+ *,
+ resource_id: Optional[str] = None,
+ timestamp: Optional[int] = None,
+ params: Optional[JsonMapping] = None,
+ ) -> str:
+ """Schedule a new potentially resumable task. A function matching the specified
+ `action` should have been previously registered with `register_action`.
+
+ Args:
+ action: the name of a previously registered action
+ resource_id: a task can be associated with a resource id to facilitate
+ getting all tasks associated with a specific resource
+ timestamp: if `None`, the task will be launched immediately, otherwise it
+ will be launch after the `timestamp` value. Note that this scheduler
+ is not meant to be precise, and the scheduling could be delayed if
+ too many tasks are already running
+ params: a set of parameters that can be easily accessed from inside the
+ executed function
+
+ Returns: the id of the scheduled task
+ """
+ if action not in self.actions:
+ raise Exception(
+ f"No function associated with the action {action} of the scheduled task"
+ )
+
+ launch_now = False
+ if timestamp is None or timestamp < self.clock.time_msec():
+ timestamp = self.clock.time_msec()
+ launch_now = True
+
+ task = ScheduledTask(
+ random_string(16),
+ action,
+ TaskStatus.SCHEDULED,
+ timestamp,
+ resource_id,
+ params,
+ None,
+ None,
+ )
+ await self.store.upsert_scheduled_task(task)
+
+ if launch_now and self.run_background_tasks:
+ await self._launch_task(task, True)
+
+ return task.id
+
+ async def update_task(
+ self,
+ id: str,
+ *,
+ status: Optional[TaskStatus] = None,
+ result: Optional[JsonMapping] = None,
+ error: Optional[str] = None,
+ ) -> None:
+ """Update some task associated values.
+
+ This is used internally, and also exposed publically so it can be used inside task functions.
+ This allows to store in DB the progress of a task so it can be resumed properly after a restart of synapse.
+
+ Args:
+ id: the id of the task to update
+ status: the new `TaskStatus` of the task
+ result: the new result of the task
+ error: the new error of the task
+ """
+ await self.store.update_scheduled_task(
+ id,
+ timestamp=self.clock.time_msec(),
+ status=status,
+ result=result,
+ error=error,
+ )
+
+ async def get_task(self, id: str) -> Optional[ScheduledTask]:
+ """Get a specific task description by id.
+
+ Args:
+ id: the id of the task to retrieve
+
+ Returns: the task description or `None` if it doesn't exist
+ or it has already been cleaned
+ """
+ return await self.store.get_scheduled_task(id)
+
+ async def get_tasks(
+ self, action: str, resource_id: Optional[str]
+ ) -> List[ScheduledTask]:
+ """Get a list of tasks associated with an action name, and
+ optionally with a resource id.
+
+ Args:
+ action: the action name of the tasks to retrieve
+ resource_id: if `None`, returns all associated tasks for
+ the specified action name, regardless of the resource id
+
+ Returns: a list of `ScheduledTask`
+ """
+ return await self.store.get_scheduled_tasks(action, resource_id)
+
+ async def _scheduled_tasks_loop(self) -> None:
+ """Main loop taking care of launching the scheduled tasks when needed."""
+ for task in await self.store.get_scheduled_tasks():
+ if task.id not in self.running_tasks:
+ if (
+ task.status == TaskStatus.SCHEDULED
+ and task.timestamp < self.clock.time_msec()
+ ):
+ await self._launch_task(task, True)
+ elif task.status == TaskStatus.ACTIVE:
+ await self._launch_task(task, False)
+ elif (
+ task.status == TaskStatus.COMPLETE
+ or task.status == TaskStatus.FAILED
+ ) and self.clock.time_msec() > task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS:
+ await self.store.delete_scheduled_task(task.id)
+
+ async def _launch_task(self, task: ScheduledTask, first_launch: bool) -> None:
+ """Launch a scheduled task now.
+
+ Args:
+ task: the task to launch
+ first_launch: `True` if it's the first time is launched, `False` otherwise
+ """
+ if task.action not in self.actions:
+ raise Exception(
+ f"No function associated with the action {task.action} of the scheduled task"
+ )
+
+ function = self.actions[task.action]
+
+ async def wrapper() -> None:
+ try:
+ (status, result, error) = await function(task, first_launch)
+ except Exception:
+ f = Failure()
+ logger.error(
+ f"scheduled task {task.id} failed",
+ exc_info=(f.type, f.value, f.getTracebackObject()),
+ )
+ status = TaskStatus.FAILED
+ result = None
+ error = f.getErrorMessage()
+
+ await self.update_task(
+ task.id,
+ status=status,
+ result=result,
+ error=error,
+ )
+ self.running_tasks.remove(task.id)
+
+ await self.update_task(task.id, status=TaskStatus.ACTIVE)
+ self.running_tasks.add(task.id)
+ description = task.action
+ if task.resource_id:
+ description += f"-{task.resource_id}"
+ run_as_background_process(description, wrapper)
diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py
new file mode 100644
index 0000000000..9fecc6b530
--- /dev/null
+++ b/tests/util/test_task_scheduler.py
@@ -0,0 +1,132 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+from twisted.internet.task import deferLater
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.types import JsonMapping, ScheduledTask, TaskStatus
+from synapse.util import Clock
+from synapse.util.task_scheduler import TaskScheduler
+
+from tests import unittest
+
+
+class TestTaskScheduler(unittest.HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.task_scheduler = hs.get_task_scheduler()
+ self.task_scheduler.register_action(self._test_task, "_test_task")
+ self.task_scheduler.register_action(self._raising_task, "_raising_task")
+ self.task_scheduler.register_action(self._resumable_task, "_resumable_task")
+
+ async def _test_task(
+ self, task: ScheduledTask, first_launch: bool
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ # This test task will copy the parameters to the result
+ result = None
+ if task.params:
+ result = task.params
+ return (TaskStatus.COMPLETE, result, None)
+
+ def test_schedule_task(self) -> None:
+ """Schedule a task in the future with some parameters to be copied as a result and check it executed correctly.
+ Also check that it get removed after `KEEP_TASKS_FOR_MS`."""
+ timestamp = self.clock.time_msec() + 2 * 60 * 1000
+ task_id = self.get_success(
+ self.task_scheduler.schedule_task(
+ "_test_task",
+ timestamp=timestamp,
+ params={"val": 1},
+ )
+ )
+
+ task = self.get_success(self.task_scheduler.get_task(task_id))
+ assert task is not None
+ self.assertEqual(task.status, TaskStatus.SCHEDULED)
+ self.assertIsNone(task.result)
+
+ # The timestamp being 2mn after now the task should been executed
+ # after the first scheduling loop is run
+ self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000) + 1)
+
+ task = self.get_success(self.task_scheduler.get_task(task_id))
+ assert task is not None
+ self.assertEqual(task.status, TaskStatus.COMPLETE)
+ assert task.result is not None
+ # The passed parameter should have been copied to the result
+ self.assertTrue(task.result.get("val") == 1)
+
+ # Let's wait for the complete task to be deleted and hence unavailable
+ self.reactor.advance((TaskScheduler.KEEP_TASKS_FOR_MS / 1000) + 1)
+
+ task = self.get_success(self.task_scheduler.get_task(task_id))
+ self.assertIsNone(task)
+
+ def test_schedule_task_now(self) -> None:
+ """Schedule a task now and check it runs fine to completion."""
+ task_id = self.get_success(
+ self.task_scheduler.schedule_task("_test_task", params={"val": 1})
+ )
+
+ task = self.get_success(self.task_scheduler.get_task(task_id))
+ assert task is not None
+ self.assertEqual(task.status, TaskStatus.COMPLETE)
+ assert task.result is not None
+ self.assertTrue(task.result.get("val") == 1)
+
+ async def _raising_task(
+ self, task: ScheduledTask, first_launch: bool
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ raise Exception("raising")
+
+ def test_schedule_raising_task_now(self) -> None:
+ """Schedule a task raising an exception and check it runs to failure and report exception content."""
+ task_id = self.get_success(self.task_scheduler.schedule_task("_raising_task"))
+
+ task = self.get_success(self.task_scheduler.get_task(task_id))
+ assert task is not None
+ self.assertEqual(task.status, TaskStatus.FAILED)
+ self.assertEqual(task.error, "raising")
+
+ async def _resumable_task(
+ self, task: ScheduledTask, first_launch: bool
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ if task.result and "in_progress" in task.result:
+ return TaskStatus.COMPLETE, {"success": True}, None
+ else:
+ await self.task_scheduler.update_task(task.id, result={"in_progress": True})
+ # Await forever to simulate an aborted task because of a restart
+ await deferLater(self.reactor, 2**16, None)
+ # This should never been called
+ return TaskStatus.ACTIVE, None, None
+
+ def test_schedule_resumable_task_now(self) -> None:
+ """Schedule a resumable task and check that it gets properly resumed and complete after simulating a synapse restart."""
+ task_id = self.get_success(self.task_scheduler.schedule_task("_resumable_task"))
+
+ task = self.get_success(self.task_scheduler.get_task(task_id))
+ assert task is not None
+ self.assertEqual(task.status, TaskStatus.ACTIVE)
+
+ # Simulate a synapse restart by emptying the list of running tasks
+ self.task_scheduler.running_tasks = set()
+ self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000) + 1)
+
+ task = self.get_success(self.task_scheduler.get_task(task_id))
+ assert task is not None
+ self.assertEqual(task.status, TaskStatus.COMPLETE)
+ assert task.result is not None
+ self.assertTrue(task.result.get("success"))
|