summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15891.feature1
-rw-r--r--synapse/app/generic_worker.py2
-rw-r--r--synapse/server.py7
-rw-r--r--synapse/storage/databases/main/__init__.py2
-rw-r--r--synapse/storage/databases/main/task_scheduler.py202
-rw-r--r--synapse/storage/schema/__init__.py1
-rw-r--r--synapse/storage/schema/main/delta/80/02_scheduled_tasks.sql28
-rw-r--r--synapse/types/__init__.py39
-rw-r--r--synapse/util/task_scheduler.py364
-rw-r--r--tests/util/test_task_scheduler.py186
10 files changed, 831 insertions, 1 deletions
diff --git a/changelog.d/15891.feature b/changelog.d/15891.feature
new file mode 100644
index 0000000000..5024b5adc4
--- /dev/null
+++ b/changelog.d/15891.feature
@@ -0,0 +1 @@
+Implements a task scheduler for	resumable potentially long running tasks.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index dc79efcc14..d25e3548e0 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -91,6 +91,7 @@ from synapse.storage.databases.main.state import StateGroupWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.databases.main.stream import StreamWorkerStore
 from synapse.storage.databases.main.tags import TagsWorkerStore
+from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore
 from synapse.storage.databases.main.transactions import TransactionWorkerStore
 from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
 from synapse.storage.databases.main.user_directory import UserDirectoryStore
@@ -144,6 +145,7 @@ class GenericWorkerStore(
     TransactionWorkerStore,
     LockStore,
     SessionStore,
+    TaskSchedulerWorkerStore,
 ):
     # Properties that multiple storage classes define. Tell mypy what the
     # expected type is.
diff --git a/synapse/server.py b/synapse/server.py
index e753ff0377..7cdd3ea3c2 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -142,6 +142,7 @@ from synapse.util.distributor import Distributor
 from synapse.util.macaroons import MacaroonGenerator
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.stringutils import random_string
+from synapse.util.task_scheduler import TaskScheduler
 
 logger = logging.getLogger(__name__)
 
@@ -360,6 +361,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         """
         for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
             getattr(self, "get_" + i + "_handler")()
+        self.get_task_scheduler()
 
     def get_reactor(self) -> ISynapseReactor:
         """
@@ -912,6 +914,9 @@ class HomeServer(metaclass=abc.ABCMeta):
         """Usage metrics shared between phone home stats and the prometheus exporter."""
         return CommonUsageMetricsManager(self)
 
-    @cache_in_self
     def get_worker_locks_handler(self) -> WorkerLocksHandler:
         return WorkerLocksHandler(self)
+
+    @cache_in_self
+    def get_task_scheduler(self) -> TaskScheduler:
+        return TaskScheduler(self)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index e17f25e87a..a85633efcd 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -70,6 +70,7 @@ from .state import StateStore
 from .stats import StatsStore
 from .stream import StreamWorkerStore
 from .tags import TagsStore
+from .task_scheduler import TaskSchedulerWorkerStore
 from .transactions import TransactionWorkerStore
 from .ui_auth import UIAuthStore
 from .user_directory import UserDirectoryStore
@@ -127,6 +128,7 @@ class DataStore(
     CacheInvalidationWorkerStore,
     LockStore,
     SessionStore,
+    TaskSchedulerWorkerStore,
 ):
     def __init__(
         self,
diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py
new file mode 100644
index 0000000000..1fb3180c3c
--- /dev/null
+++ b/synapse/storage/databases/main/task_scheduler.py
@@ -0,0 +1,202 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
+from synapse.types import JsonDict, JsonMapping, ScheduledTask, TaskStatus
+from synapse.util import json_encoder
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+class TaskSchedulerWorkerStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+    @staticmethod
+    def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask:
+        row["status"] = TaskStatus(row["status"])
+        if row["params"] is not None:
+            row["params"] = db_to_json(row["params"])
+        if row["result"] is not None:
+            row["result"] = db_to_json(row["result"])
+        return ScheduledTask(**row)
+
+    async def get_scheduled_tasks(
+        self,
+        *,
+        actions: Optional[List[str]] = None,
+        resource_id: Optional[str] = None,
+        statuses: Optional[List[TaskStatus]] = None,
+        max_timestamp: Optional[int] = None,
+    ) -> List[ScheduledTask]:
+        """Get a list of scheduled tasks from the DB.
+
+        Args:
+            actions: Limit the returned tasks to those specific action names
+            resource_id: Limit the returned tasks to the specific resource id, if specified
+            statuses: Limit the returned tasks to the specific statuses
+            max_timestamp: Limit the returned tasks to the ones that have
+                a timestamp inferior to the specified one
+
+        Returns: a list of `ScheduledTask`, ordered by increasing timestamps
+        """
+
+        def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+            clauses: List[str] = []
+            args: List[Any] = []
+            if resource_id:
+                clauses.append("resource_id = ?")
+                args.append(resource_id)
+            if actions is not None:
+                clause, temp_args = make_in_list_sql_clause(
+                    txn.database_engine, "action", actions
+                )
+                clauses.append(clause)
+                args.extend(temp_args)
+            if statuses is not None:
+                clause, temp_args = make_in_list_sql_clause(
+                    txn.database_engine, "status", statuses
+                )
+                clauses.append(clause)
+                args.extend(temp_args)
+            if max_timestamp is not None:
+                clauses.append("timestamp <= ?")
+                args.append(max_timestamp)
+
+            sql = "SELECT * FROM scheduled_tasks"
+            if clauses:
+                sql = sql + " WHERE " + " AND ".join(clauses)
+
+            sql = sql + "ORDER BY timestamp"
+
+            txn.execute(sql, args)
+            return self.db_pool.cursor_to_dict(txn)
+
+        rows = await self.db_pool.runInteraction(
+            "get_scheduled_tasks", get_scheduled_tasks_txn
+        )
+        return [TaskSchedulerWorkerStore._convert_row_to_task(row) for row in rows]
+
+    async def insert_scheduled_task(self, task: ScheduledTask) -> None:
+        """Insert a specified `ScheduledTask` in the DB.
+
+        Args:
+            task: the `ScheduledTask` to insert
+        """
+        await self.db_pool.simple_insert(
+            "scheduled_tasks",
+            {
+                "id": task.id,
+                "action": task.action,
+                "status": task.status,
+                "timestamp": task.timestamp,
+                "resource_id": task.resource_id,
+                "params": None
+                if task.params is None
+                else json_encoder.encode(task.params),
+                "result": None
+                if task.result is None
+                else json_encoder.encode(task.result),
+                "error": task.error,
+            },
+            desc="insert_scheduled_task",
+        )
+
+    async def update_scheduled_task(
+        self,
+        id: str,
+        timestamp: int,
+        *,
+        status: Optional[TaskStatus] = None,
+        result: Optional[JsonMapping] = None,
+        error: Optional[str] = None,
+    ) -> bool:
+        """Update a scheduled task in the DB with some new value(s).
+
+        Args:
+            id: id of the `ScheduledTask` to update
+            timestamp: new timestamp of the task
+            status: new status of the task
+            result: new result of the task
+            error: new error of the task
+
+        Returns: `False` if no matching row was found, `True` otherwise
+        """
+        updatevalues: JsonDict = {"timestamp": timestamp}
+        if status is not None:
+            updatevalues["status"] = status
+        if result is not None:
+            updatevalues["result"] = json_encoder.encode(result)
+        if error is not None:
+            updatevalues["error"] = error
+        nb_rows = await self.db_pool.simple_update(
+            "scheduled_tasks",
+            {"id": id},
+            updatevalues,
+            desc="update_scheduled_task",
+        )
+        return nb_rows > 0
+
+    async def get_scheduled_task(self, id: str) -> Optional[ScheduledTask]:
+        """Get a specific `ScheduledTask` from its id.
+
+        Args:
+            id: the id of the task to retrieve
+
+        Returns: the task if available, `None` otherwise
+        """
+        row = await self.db_pool.simple_select_one(
+            table="scheduled_tasks",
+            keyvalues={"id": id},
+            retcols=(
+                "id",
+                "action",
+                "status",
+                "timestamp",
+                "resource_id",
+                "params",
+                "result",
+                "error",
+            ),
+            allow_none=True,
+            desc="get_scheduled_task",
+        )
+
+        return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
+
+    async def delete_scheduled_task(self, id: str) -> None:
+        """Delete a specific task from its id.
+
+        Args:
+            id: the id of the task to delete
+        """
+        await self.db_pool.simple_delete(
+            "scheduled_tasks",
+            keyvalues={"id": id},
+            desc="delete_scheduled_task",
+        )
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 7de9949a5b..649d3c8e9f 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -113,6 +113,7 @@ Changes in SCHEMA_VERSION = 79
 
 Changes in SCHEMA_VERSION = 80
     - The event_txn_id_device_id is always written to for new events.
+    - Add tables for the task scheduler.
 """
 
 
diff --git a/synapse/storage/schema/main/delta/80/02_scheduled_tasks.sql b/synapse/storage/schema/main/delta/80/02_scheduled_tasks.sql
new file mode 100644
index 0000000000..286d109ed7
--- /dev/null
+++ b/synapse/storage/schema/main/delta/80/02_scheduled_tasks.sql
@@ -0,0 +1,28 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- cf ScheduledTask docstring for the meaning of the fields.
+CREATE TABLE IF NOT EXISTS scheduled_tasks(
+    id TEXT PRIMARY KEY,
+    action TEXT NOT NULL,
+    status TEXT NOT NULL,
+    timestamp BIGINT NOT NULL,
+    resource_id TEXT,
+    params TEXT,
+    result TEXT,
+    error TEXT
+);
+
+CREATE INDEX IF NOT EXISTS scheduled_tasks_status ON scheduled_tasks(status);
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 073f682aca..e750417189 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -15,6 +15,7 @@
 import abc
 import re
 import string
+from enum import Enum
 from typing import (
     TYPE_CHECKING,
     AbstractSet,
@@ -969,3 +970,41 @@ class UserProfile(TypedDict):
 class RetentionPolicy:
     min_lifetime: Optional[int] = None
     max_lifetime: Optional[int] = None
+
+
+class TaskStatus(str, Enum):
+    """Status of a scheduled task"""
+
+    # Task is scheduled but not active
+    SCHEDULED = "scheduled"
+    # Task is active and probably running, and if not
+    # will be run on next scheduler loop run
+    ACTIVE = "active"
+    # Task has completed successfully
+    COMPLETE = "complete"
+    # Task is over and either returned a failed status, or had an exception
+    FAILED = "failed"
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class ScheduledTask:
+    """Description of a scheduled task"""
+
+    # Id used to identify the task
+    id: str
+    # Name of the action to be run by this task
+    action: str
+    # Current status of this task
+    status: TaskStatus
+    # If the status is SCHEDULED then this represents when it should be launched,
+    # otherwise it represents the last time this task got a change of state.
+    # In milliseconds since epoch in system time timezone, usually UTC.
+    timestamp: int
+    # Optionally bind a task to some resource id for easy retrieval
+    resource_id: Optional[str]
+    # Optional parameters that will be passed to the function ran by the task
+    params: Optional[JsonMapping]
+    # Optional result that can be updated by the running task
+    result: Optional[JsonMapping]
+    # Optional error that should be assigned a value when the status is FAILED
+    error: Optional[str]
diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
new file mode 100644
index 0000000000..773a8327f6
--- /dev/null
+++ b/synapse/util/task_scheduler.py
@@ -0,0 +1,364 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set, Tuple
+
+from prometheus_client import Gauge
+
+from twisted.python.failure import Failure
+
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonMapping, ScheduledTask, TaskStatus
+from synapse.util.stringutils import random_string
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+running_tasks_gauge = Gauge(
+    "synapse_scheduler_running_tasks",
+    "The number of concurrent running tasks handled by the TaskScheduler",
+)
+
+
+class TaskScheduler:
+    """
+    This is a simple task sheduler aimed at resumable tasks: usually we use `run_in_background`
+    to launch a background task, or Twisted `deferLater` if we want to do so later on.
+
+    The problem with that is that the tasks will just stop and never be resumed if synapse
+    is stopped for whatever reason.
+
+    How this works:
+    - A function mapped to a named action should first be registered with `register_action`.
+    This function will be called when trying to resuming tasks after a synapse shutdown,
+    so this registration should happen when synapse is initialised, NOT right before scheduling
+    a task.
+    - A task can then be launched using this named action with `schedule_task`. A `params` dict
+    can be passed, and it will be available to the registered function when launched. This task
+    can be launch either now-ish, or later on by giving a `timestamp` parameter.
+
+    The function may call `update_task` at any time to update the `result` of the task,
+    and this can be used to resume the task at a specific point and/or to convey a result to
+    the code launching the task.
+    You can also specify the `result` (and/or an `error`) when returning from the function.
+
+    The reconciliation loop runs every 5 mns, so this is not a precise scheduler. When wanting
+    to launch now, the launch will still not happen before the next loop run.
+
+    Tasks will be run on the worker specified with `run_background_tasks_on` config,
+    or the main one by default.
+    There is a limit of 10 concurrent tasks, so tasks may be delayed if the pool is already
+    full. In this regard, please take great care that scheduled tasks can actually finished.
+    For now there is no mechanism to stop a running task if it is stuck.
+    """
+
+    # Precision of the scheduler, evaluation of tasks to run will only happen
+    # every `SCHEDULE_INTERVAL_MS` ms
+    SCHEDULE_INTERVAL_MS = 1 * 60 * 1000  # 1mn
+    # Time before a complete or failed task is deleted from the DB
+    KEEP_TASKS_FOR_MS = 7 * 24 * 60 * 60 * 1000  # 1 week
+    # Maximum number of tasks that can run at the same time
+    MAX_CONCURRENT_RUNNING_TASKS = 10
+    # Time from the last task update after which we will log a warning
+    LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000  # 24hrs
+
+    def __init__(self, hs: "HomeServer"):
+        self._store = hs.get_datastores().main
+        self._clock = hs.get_clock()
+        self._running_tasks: Set[str] = set()
+        # A map between action names and their registered function
+        self._actions: Dict[
+            str,
+            Callable[
+                [ScheduledTask, bool],
+                Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]],
+            ],
+        ] = {}
+        self._run_background_tasks = hs.config.worker.run_background_tasks
+
+        if self._run_background_tasks:
+            self._clock.looping_call(
+                run_as_background_process,
+                TaskScheduler.SCHEDULE_INTERVAL_MS,
+                "handle_scheduled_tasks",
+                self._handle_scheduled_tasks,
+            )
+
+    def register_action(
+        self,
+        function: Callable[
+            [ScheduledTask, bool],
+            Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]],
+        ],
+        action_name: str,
+    ) -> None:
+        """Register a function to be executed when an action is scheduled with
+        the specified action name.
+
+        Actions need to be registered as early as possible so that a resumed action
+        can find its matching function. It's usually better to NOT do that right before
+        calling `schedule_task` but rather in an `__init__` method.
+
+        Args:
+            function: The function to be executed for this action. The parameters
+                passed to the function when launched are the `ScheduledTask` being run,
+                and a `first_launch` boolean to signal if it's a resumed task or the first
+                launch of it. The function should return a tuple of new `status`, `result`
+                and `error` as specified in `ScheduledTask`.
+            action_name: The name of the action to be associated with the function
+        """
+        self._actions[action_name] = function
+
+    async def schedule_task(
+        self,
+        action: str,
+        *,
+        resource_id: Optional[str] = None,
+        timestamp: Optional[int] = None,
+        params: Optional[JsonMapping] = None,
+    ) -> str:
+        """Schedule a new potentially resumable task. A function matching the specified
+        `action` should have been previously registered with `register_action`.
+
+        Args:
+            action: the name of a previously registered action
+            resource_id: a task can be associated with a resource id to facilitate
+                getting all tasks associated with a specific resource
+            timestamp: if `None`, the task will be launched as soon as possible, otherwise it
+                will be launch as soon as possible after the `timestamp` value.
+                Note that this scheduler is not meant to be precise, and the scheduling
+                could be delayed if too many tasks are already running
+            params: a set of parameters that can be easily accessed from inside the
+                executed function
+
+        Returns:
+            The id of the scheduled task
+        """
+        if action not in self._actions:
+            raise Exception(
+                f"No function associated with action {action} of the scheduled task"
+            )
+
+        if timestamp is None or timestamp < self._clock.time_msec():
+            timestamp = self._clock.time_msec()
+
+        task = ScheduledTask(
+            random_string(16),
+            action,
+            TaskStatus.SCHEDULED,
+            timestamp,
+            resource_id,
+            params,
+            result=None,
+            error=None,
+        )
+        await self._store.insert_scheduled_task(task)
+
+        return task.id
+
+    async def update_task(
+        self,
+        id: str,
+        *,
+        timestamp: Optional[int] = None,
+        status: Optional[TaskStatus] = None,
+        result: Optional[JsonMapping] = None,
+        error: Optional[str] = None,
+    ) -> bool:
+        """Update some task associated values. This is exposed publically so it can
+        be used inside task functions, mainly to update the result and be able to
+        resume a task at a specific step after a restart of synapse.
+
+        It can also be used to stage a task, by setting the `status` to `SCHEDULED` with
+        a new timestamp.
+
+        The `status` can only be set to `ACTIVE` or `SCHEDULED`, `COMPLETE` and `FAILED`
+        are terminal status and can only be set by returning it in the function.
+
+        Args:
+            id: the id of the task to update
+            timestamp: useful to schedule a new stage of the task at a later date
+            status: the new `TaskStatus` of the task
+            result: the new result of the task
+            error: the new error of the task
+        """
+        if status == TaskStatus.COMPLETE or status == TaskStatus.FAILED:
+            raise Exception(
+                "update_task can't be called with a FAILED or COMPLETE status"
+            )
+
+        if timestamp is None:
+            timestamp = self._clock.time_msec()
+        return await self._store.update_scheduled_task(
+            id,
+            timestamp,
+            status=status,
+            result=result,
+            error=error,
+        )
+
+    async def get_task(self, id: str) -> Optional[ScheduledTask]:
+        """Get a specific task description by id.
+
+        Args:
+            id: the id of the task to retrieve
+
+        Returns:
+            The task information or `None` if it doesn't exist or it has
+            already been removed because it's too old.
+        """
+        return await self._store.get_scheduled_task(id)
+
+    async def get_tasks(
+        self,
+        *,
+        actions: Optional[List[str]] = None,
+        resource_id: Optional[str] = None,
+        statuses: Optional[List[TaskStatus]] = None,
+        max_timestamp: Optional[int] = None,
+    ) -> List[ScheduledTask]:
+        """Get a list of tasks. Returns all the tasks if no args is provided.
+
+        If an arg is `None` all tasks matching the other args will be selected.
+        If an arg is an empty list, the corresponding value of the task needs
+        to be `None` to be selected.
+
+        Args:
+            actions: Limit the returned tasks to those specific action names
+            resource_id: Limit the returned tasks to the specific resource id, if specified
+            statuses: Limit the returned tasks to the specific statuses
+            max_timestamp: Limit the returned tasks to the ones that have
+                a timestamp inferior to the specified one
+
+        Returns
+            A list of `ScheduledTask`, ordered by increasing timestamps
+        """
+        return await self._store.get_scheduled_tasks(
+            actions=actions,
+            resource_id=resource_id,
+            statuses=statuses,
+            max_timestamp=max_timestamp,
+        )
+
+    async def delete_task(self, id: str) -> None:
+        """Delete a task. Running tasks can't be deleted.
+
+        Can only be called from the worker handling the task scheduling.
+
+        Args:
+            id: id of the task to delete
+        """
+        if self.task_is_running(id):
+            raise Exception(f"Task {id} is currently running and can't be deleted")
+        await self._store.delete_scheduled_task(id)
+
+    def task_is_running(self, id: str) -> bool:
+        """Check if a task is currently running.
+
+        Can only be called from the worker handling the task scheduling.
+
+        Args:
+            id: id of the task to check
+        """
+        assert self._run_background_tasks
+        return id in self._running_tasks
+
+    async def _handle_scheduled_tasks(self) -> None:
+        """Main loop taking care of launching tasks and cleaning up old ones."""
+        await self._launch_scheduled_tasks()
+        await self._clean_scheduled_tasks()
+
+    async def _launch_scheduled_tasks(self) -> None:
+        """Retrieve and launch scheduled tasks that should be running at that time."""
+        for task in await self.get_tasks(statuses=[TaskStatus.ACTIVE]):
+            if not self.task_is_running(task.id):
+                if (
+                    len(self._running_tasks)
+                    < TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS
+                ):
+                    await self._launch_task(task, first_launch=False)
+            else:
+                if (
+                    self._clock.time_msec()
+                    > task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS
+                ):
+                    logger.warn(
+                        f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck"
+                    )
+        for task in await self.get_tasks(
+            statuses=[TaskStatus.SCHEDULED], max_timestamp=self._clock.time_msec()
+        ):
+            if (
+                not self.task_is_running(task.id)
+                and len(self._running_tasks)
+                < TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS
+            ):
+                await self._launch_task(task, first_launch=True)
+
+        running_tasks_gauge.set(len(self._running_tasks))
+
+    async def _clean_scheduled_tasks(self) -> None:
+        """Clean old complete or failed jobs to avoid clutter the DB."""
+        for task in await self._store.get_scheduled_tasks(
+            statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE]
+        ):
+            # FAILED and COMPLETE tasks should never be running
+            assert not self.task_is_running(task.id)
+            if (
+                self._clock.time_msec()
+                > task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS
+            ):
+                await self._store.delete_scheduled_task(task.id)
+
+    async def _launch_task(self, task: ScheduledTask, first_launch: bool) -> None:
+        """Launch a scheduled task now.
+
+        Args:
+            task: the task to launch
+            first_launch: `True` if it's the first time is launched, `False` otherwise
+        """
+        assert task.action in self._actions
+
+        function = self._actions[task.action]
+
+        async def wrapper() -> None:
+            try:
+                (status, result, error) = await function(task, first_launch)
+            except Exception:
+                f = Failure()
+                logger.error(
+                    f"scheduled task {task.id} failed",
+                    exc_info=(f.type, f.value, f.getTracebackObject()),
+                )
+                status = TaskStatus.FAILED
+                result = None
+                error = f.getErrorMessage()
+
+            await self._store.update_scheduled_task(
+                task.id,
+                self._clock.time_msec(),
+                status=status,
+                result=result,
+                error=error,
+            )
+            self._running_tasks.remove(task.id)
+
+        self._running_tasks.add(task.id)
+        await self.update_task(task.id, status=TaskStatus.ACTIVE)
+        description = f"{task.id}-{task.action}"
+        run_as_background_process(description, wrapper)
diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py
new file mode 100644
index 0000000000..3a97559bf0
--- /dev/null
+++ b/tests/util/test_task_scheduler.py
@@ -0,0 +1,186 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+from twisted.internet.task import deferLater
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.types import JsonMapping, ScheduledTask, TaskStatus
+from synapse.util import Clock
+from synapse.util.task_scheduler import TaskScheduler
+
+from tests import unittest
+
+
+class TestTaskScheduler(unittest.HomeserverTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.task_scheduler = hs.get_task_scheduler()
+        self.task_scheduler.register_action(self._test_task, "_test_task")
+        self.task_scheduler.register_action(self._sleeping_task, "_sleeping_task")
+        self.task_scheduler.register_action(self._raising_task, "_raising_task")
+        self.task_scheduler.register_action(self._resumable_task, "_resumable_task")
+
+    async def _test_task(
+        self, task: ScheduledTask, first_launch: bool
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        # This test task will copy the parameters to the result
+        result = None
+        if task.params:
+            result = task.params
+        return (TaskStatus.COMPLETE, result, None)
+
+    def test_schedule_task(self) -> None:
+        """Schedule a task in the future with some parameters to be copied as a result and check it executed correctly.
+        Also check that it get removed after `KEEP_TASKS_FOR_MS`."""
+        timestamp = self.clock.time_msec() + 30 * 1000
+        task_id = self.get_success(
+            self.task_scheduler.schedule_task(
+                "_test_task",
+                timestamp=timestamp,
+                params={"val": 1},
+            )
+        )
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        assert task is not None
+        self.assertEqual(task.status, TaskStatus.SCHEDULED)
+        self.assertIsNone(task.result)
+
+        # The timestamp being 30s after now the task should been executed
+        # after the first scheduling loop is run
+        self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        assert task is not None
+        self.assertEqual(task.status, TaskStatus.COMPLETE)
+        assert task.result is not None
+        # The passed parameter should have been copied to the result
+        self.assertTrue(task.result.get("val") == 1)
+
+        # Let's wait for the complete task to be deleted and hence unavailable
+        self.reactor.advance((TaskScheduler.KEEP_TASKS_FOR_MS / 1000) + 1)
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        self.assertIsNone(task)
+
+    async def _sleeping_task(
+        self, task: ScheduledTask, first_launch: bool
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        # Sleep for a second
+        await deferLater(self.reactor, 1, lambda: None)
+        return TaskStatus.COMPLETE, None, None
+
+    def test_schedule_lot_of_tasks(self) -> None:
+        """Schedule more than `TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS` tasks and check the behavior."""
+        timestamp = self.clock.time_msec() + 30 * 1000
+        task_ids = []
+        for i in range(TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS + 1):
+            task_ids.append(
+                self.get_success(
+                    self.task_scheduler.schedule_task(
+                        "_sleeping_task",
+                        timestamp=timestamp,
+                        params={"val": i},
+                    )
+                )
+            )
+
+        # The timestamp being 30s after now the task should been executed
+        # after the first scheduling loop is run
+        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
+
+        # This is to give the time to the sleeping tasks to finish
+        self.reactor.advance(1)
+
+        # Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one
+        # is still scheduled.
+        tasks = [
+            self.get_success(self.task_scheduler.get_task(task_id))
+            for task_id in task_ids
+        ]
+
+        self.assertEquals(
+            len(
+                [t for t in tasks if t is not None and t.status == TaskStatus.COMPLETE]
+            ),
+            TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
+        )
+
+        scheduled_tasks = [
+            t for t in tasks if t is not None and t.status == TaskStatus.SCHEDULED
+        ]
+        self.assertEquals(len(scheduled_tasks), 1)
+
+        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
+        self.reactor.advance(1)
+
+        # Check that the last task has been properly executed after the next scheduler loop run
+        prev_scheduled_task = self.get_success(
+            self.task_scheduler.get_task(scheduled_tasks[0].id)
+        )
+        assert prev_scheduled_task is not None
+        self.assertEquals(
+            prev_scheduled_task.status,
+            TaskStatus.COMPLETE,
+        )
+
+    async def _raising_task(
+        self, task: ScheduledTask, first_launch: bool
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        raise Exception("raising")
+
+    def test_schedule_raising_task(self) -> None:
+        """Schedule a task raising an exception and check it runs to failure and report exception content."""
+        task_id = self.get_success(self.task_scheduler.schedule_task("_raising_task"))
+
+        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        assert task is not None
+        self.assertEqual(task.status, TaskStatus.FAILED)
+        self.assertEqual(task.error, "raising")
+
+    async def _resumable_task(
+        self, task: ScheduledTask, first_launch: bool
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        if task.result and "in_progress" in task.result:
+            return TaskStatus.COMPLETE, {"success": True}, None
+        else:
+            await self.task_scheduler.update_task(task.id, result={"in_progress": True})
+            # Await forever to simulate an aborted task because of a restart
+            await deferLater(self.reactor, 2**16, lambda: None)
+            # This should never been called
+            return TaskStatus.ACTIVE, None, None
+
+    def test_schedule_resumable_task(self) -> None:
+        """Schedule a resumable task and check that it gets properly resumed and complete after simulating a synapse restart."""
+        task_id = self.get_success(self.task_scheduler.schedule_task("_resumable_task"))
+
+        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        assert task is not None
+        self.assertEqual(task.status, TaskStatus.ACTIVE)
+
+        # Simulate a synapse restart by emptying the list of running tasks
+        self.task_scheduler._running_tasks = set()
+        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        assert task is not None
+        self.assertEqual(task.status, TaskStatus.COMPLETE)
+        assert task.result is not None
+        self.assertTrue(task.result.get("success"))