summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/async_helpers.py25
-rw-r--r--synapse/util/caches/dictionary_cache.py10
-rw-r--r--synapse/util/caches/expiringcache.py24
-rw-r--r--synapse/util/caches/ttlcache.py10
-rw-r--r--synapse/util/gai_resolver.py2
-rw-r--r--synapse/util/task_scheduler.py158
6 files changed, 132 insertions, 97 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 943ad54456..0cbeb0c365 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -19,6 +19,7 @@ import collections
 import inspect
 import itertools
 import logging
+import typing
 from contextlib import asynccontextmanager
 from typing import (
     Any,
@@ -29,6 +30,7 @@ from typing import (
     Collection,
     Coroutine,
     Dict,
+    Generator,
     Generic,
     Hashable,
     Iterable,
@@ -398,7 +400,7 @@ class _LinearizerEntry:
     # The number of things executing.
     count: int
     # Deferreds for the things blocked from executing.
-    deferreds: collections.OrderedDict
+    deferreds: typing.OrderedDict["defer.Deferred[None]", Literal[1]]
 
 
 class Linearizer:
@@ -717,30 +719,25 @@ def timeout_deferred(
     return new_d
 
 
-# This class can't be generic because it uses slots with attrs.
-# See: https://github.com/python-attrs/attrs/issues/313
 @attr.s(slots=True, frozen=True, auto_attribs=True)
-class DoneAwaitable:  # should be: Generic[R]
+class DoneAwaitable(Awaitable[R]):
     """Simple awaitable that returns the provided value."""
 
-    value: Any  # should be: R
+    value: R
 
-    def __await__(self) -> Any:
-        return self
-
-    def __iter__(self) -> "DoneAwaitable":
-        return self
-
-    def __next__(self) -> None:
-        raise StopIteration(self.value)
+    def __await__(self) -> Generator[Any, None, R]:
+        yield None
+        return self.value
 
 
 def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
     """Convert a value to an awaitable if not already an awaitable."""
     if inspect.isawaitable(value):
-        assert isinstance(value, Awaitable)
         return value
 
+    # For some reason mypy doesn't deduce that value is not Awaitable here, even though
+    # inspect.isawaitable returns a TypeGuard.
+    assert not isinstance(value, Awaitable)
     return DoneAwaitable(value)
 
 
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 5eaf70c7ab..2fbc7b1e6c 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -14,7 +14,7 @@
 import enum
 import logging
 import threading
-from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
+from typing import Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
 
 import attr
 from typing_extensions import Literal
@@ -33,10 +33,8 @@ DKT = TypeVar("DKT")
 DV = TypeVar("DV")
 
 
-# This class can't be generic because it uses slots with attrs.
-# See: https://github.com/python-attrs/attrs/issues/313
 @attr.s(slots=True, frozen=True, auto_attribs=True)
-class DictionaryEntry:  # should be: Generic[DKT, DV].
+class DictionaryEntry(Generic[DKT, DV]):
     """Returned when getting an entry from the cache
 
     If `full` is true then `known_absent` will be the empty set.
@@ -50,8 +48,8 @@ class DictionaryEntry:  # should be: Generic[DKT, DV].
     """
 
     full: bool
-    known_absent: Set[Any]  # should be: Set[DKT]
-    value: Dict[Any, Any]  # should be: Dict[DKT, DV]
+    known_absent: Set[DKT]
+    value: Dict[DKT, DV]
 
     def __len__(self) -> int:
         return len(self.value)
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 01ad02af67..e73cf66080 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -14,7 +14,7 @@
 
 import logging
 from collections import OrderedDict
-from typing import Any, Generic, Optional, TypeVar, Union, overload
+from typing import Any, Generic, Iterable, Optional, TypeVar, Union, overload
 
 import attr
 from typing_extensions import Literal
@@ -73,7 +73,7 @@ class ExpiringCache(Generic[KT, VT]):
         self._expiry_ms = expiry_ms
         self._reset_expiry_on_get = reset_expiry_on_get
 
-        self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict()
+        self._cache: OrderedDict[KT, _CacheEntry[VT]] = OrderedDict()
 
         self.iterable = iterable
 
@@ -84,9 +84,7 @@ class ExpiringCache(Generic[KT, VT]):
             return
 
         def f() -> "defer.Deferred[None]":
-            return run_as_background_process(
-                "prune_cache_%s" % self._cache_name, self._prune_cache
-            )
+            return run_as_background_process("prune_cache", self._prune_cache)
 
         self._clock.looping_call(f, self._expiry_ms / 2)
 
@@ -100,7 +98,10 @@ class ExpiringCache(Generic[KT, VT]):
         while self._max_size and len(self) > self._max_size:
             _key, value = self._cache.popitem(last=False)
             if self.iterable:
-                self.metrics.inc_evictions(EvictionReason.size, len(value.value))
+                # type-ignore, here and below: if self.iterable is true, then the value
+                # type VT should be Sized (i.e. have a __len__ method). We don't enforce
+                # this via the type system at present.
+                self.metrics.inc_evictions(EvictionReason.size, len(value.value))  # type: ignore[arg-type]
             else:
                 self.metrics.inc_evictions(EvictionReason.size)
 
@@ -134,7 +135,7 @@ class ExpiringCache(Generic[KT, VT]):
             return default
 
         if self.iterable:
-            self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value))
+            self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value))  # type: ignore[arg-type]
         else:
             self.metrics.inc_evictions(EvictionReason.invalidation)
 
@@ -182,7 +183,7 @@ class ExpiringCache(Generic[KT, VT]):
         for k in keys_to_delete:
             value = self._cache.pop(k)
             if self.iterable:
-                self.metrics.inc_evictions(EvictionReason.time, len(value.value))
+                self.metrics.inc_evictions(EvictionReason.time, len(value.value))  # type: ignore[arg-type]
             else:
                 self.metrics.inc_evictions(EvictionReason.time)
 
@@ -195,7 +196,8 @@ class ExpiringCache(Generic[KT, VT]):
 
     def __len__(self) -> int:
         if self.iterable:
-            return sum(len(entry.value) for entry in self._cache.values())
+            g: Iterable[int] = (len(entry.value) for entry in self._cache.values())  # type: ignore[arg-type]
+            return sum(g)
         else:
             return len(self._cache)
 
@@ -218,6 +220,6 @@ class ExpiringCache(Generic[KT, VT]):
 
 
 @attr.s(slots=True, auto_attribs=True)
-class _CacheEntry:
+class _CacheEntry(Generic[VT]):
     time: int
-    value: Any
+    value: VT
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index f6b3ee31e4..48a6e4a906 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]):
 
     def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
         # map from key to _CacheEntry
-        self._data: Dict[KT, _CacheEntry] = {}
+        self._data: Dict[KT, _CacheEntry[KT, VT]] = {}
 
         # the _CacheEntries, sorted by expiry time
-        self._expiry_list: SortedList[_CacheEntry] = SortedList()
+        self._expiry_list: SortedList[_CacheEntry[KT, VT]] = SortedList()
 
         self._timer = timer
 
@@ -160,11 +160,11 @@ class TTLCache(Generic[KT, VT]):
 
 
 @attr.s(frozen=True, slots=True, auto_attribs=True)
-class _CacheEntry:  # Should be Generic[KT, VT]. See python-attrs/attrs#313
+class _CacheEntry(Generic[KT, VT]):
     """TTLCache entry"""
 
     # expiry_time is the first attribute, so that entries are sorted by expiry.
     expiry_time: float
     ttl: float
-    key: Any  # should be KT
-    value: Any  # should be VT
+    key: KT
+    value: VT
diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py
index 214eb17fbc..fecf829ade 100644
--- a/synapse/util/gai_resolver.py
+++ b/synapse/util/gai_resolver.py
@@ -136,7 +136,7 @@ class GAIResolver:
 
     # The types on IHostnameResolver is incorrect in Twisted, see
     # https://twistedmatrix.com/trac/ticket/10276
-    def resolveHostName(  # type: ignore[override]
+    def resolveHostName(
         self,
         resolutionReceiver: IResolutionReceiver,
         hostName: str,
diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
index 9e89aeb748..caf13b3474 100644
--- a/synapse/util/task_scheduler.py
+++ b/synapse/util/task_scheduler.py
@@ -15,11 +15,14 @@
 import logging
 from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set, Tuple
 
-from prometheus_client import Gauge
-
 from twisted.python.failure import Failure
 
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.logging.context import nested_logging_context
+from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.types import JsonMapping, ScheduledTask, TaskStatus
 from synapse.util.stringutils import random_string
 
@@ -29,12 +32,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-running_tasks_gauge = Gauge(
-    "synapse_scheduler_running_tasks",
-    "The number of concurrent running tasks handled by the TaskScheduler",
-)
-
-
 class TaskScheduler:
     """
     This is a simple task sheduler aimed at resumable tasks: usually we use `run_in_background`
@@ -69,6 +66,8 @@ class TaskScheduler:
     # Precision of the scheduler, evaluation of tasks to run will only happen
     # every `SCHEDULE_INTERVAL_MS` ms
     SCHEDULE_INTERVAL_MS = 1 * 60 * 1000  # 1mn
+    # How often to clean up old tasks.
+    CLEANUP_INTERVAL_MS = 30 * 60 * 1000
     # Time before a complete or failed task is deleted from the DB
     KEEP_TASKS_FOR_MS = 7 * 24 * 60 * 60 * 1000  # 1 week
     # Maximum number of tasks that can run at the same time
@@ -77,6 +76,7 @@ class TaskScheduler:
     LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000  # 24hrs
 
     def __init__(self, hs: "HomeServer"):
+        self._hs = hs
         self._store = hs.get_datastores().main
         self._clock = hs.get_clock()
         self._running_tasks: Set[str] = set()
@@ -90,15 +90,25 @@ class TaskScheduler:
         ] = {}
         self._run_background_tasks = hs.config.worker.run_background_tasks
 
+        # Flag to make sure we only try and launch new tasks once at a time.
+        self._launching_new_tasks = False
+
         if self._run_background_tasks:
             self._clock.looping_call(
-                run_as_background_process,
+                self._launch_scheduled_tasks,
+                TaskScheduler.SCHEDULE_INTERVAL_MS,
+            )
+            self._clock.looping_call(
+                self._clean_scheduled_tasks,
                 TaskScheduler.SCHEDULE_INTERVAL_MS,
-                "handle_scheduled_tasks",
-                self._handle_scheduled_tasks,
             )
-        else:
-            self.replication_client = hs.get_replication_command_handler()
+
+        LaterGauge(
+            "synapse_scheduler_running_tasks",
+            "The number of concurrent running tasks handled by the TaskScheduler",
+            labels=None,
+            caller=lambda: len(self._running_tasks),
+        )
 
     def register_action(
         self,
@@ -133,7 +143,7 @@ class TaskScheduler:
         params: Optional[JsonMapping] = None,
     ) -> str:
         """Schedule a new potentially resumable task. A function matching the specified
-        `action` should have been previously registered with `register_action`.
+        `action` should have be registered with `register_action` before the task is run.
 
         Args:
             action: the name of a previously registered action
@@ -149,11 +159,6 @@ class TaskScheduler:
         Returns:
             The id of the scheduled task
         """
-        if action not in self._actions:
-            raise Exception(
-                f"No function associated with action {action} of the scheduled task"
-            )
-
         status = TaskStatus.SCHEDULED
         if timestamp is None or timestamp < self._clock.time_msec():
             timestamp = self._clock.time_msec()
@@ -175,7 +180,7 @@ class TaskScheduler:
             if self._run_background_tasks:
                 await self._launch_task(task)
             else:
-                self.replication_client.send_new_active_task(task.id)
+                self._hs.get_replication_command_handler().send_new_active_task(task.id)
 
         return task.id
 
@@ -239,6 +244,7 @@ class TaskScheduler:
         resource_id: Optional[str] = None,
         statuses: Optional[List[TaskStatus]] = None,
         max_timestamp: Optional[int] = None,
+        limit: Optional[int] = None,
     ) -> List[ScheduledTask]:
         """Get a list of tasks. Returns all the tasks if no args is provided.
 
@@ -252,6 +258,7 @@ class TaskScheduler:
             statuses: Limit the returned tasks to the specific statuses
             max_timestamp: Limit the returned tasks to the ones that have
                 a timestamp inferior to the specified one
+            limit: Only return `limit` number of rows if set.
 
         Returns
             A list of `ScheduledTask`, ordered by increasing timestamps
@@ -261,6 +268,7 @@ class TaskScheduler:
             resource_id=resource_id,
             statuses=statuses,
             max_timestamp=max_timestamp,
+            limit=limit,
         )
 
     async def delete_task(self, id: str) -> None:
@@ -278,34 +286,58 @@ class TaskScheduler:
             raise Exception(f"Task {id} is currently ACTIVE and can't be deleted")
         await self._store.delete_scheduled_task(id)
 
-    async def _handle_scheduled_tasks(self) -> None:
-        """Main loop taking care of launching tasks and cleaning up old ones."""
-        await self._launch_scheduled_tasks()
-        await self._clean_scheduled_tasks()
+    def launch_task_by_id(self, id: str) -> None:
+        """Try launching the task with the given ID."""
+        # Don't bother trying to launch new tasks if we're already at capacity.
+        if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
+            return
+
+        run_as_background_process("launch_task_by_id", self._launch_task_by_id, id)
+
+    async def _launch_task_by_id(self, id: str) -> None:
+        """Helper async function for `launch_task_by_id`."""
+        task = await self.get_task(id)
+        if task:
+            await self._launch_task(task)
 
+    @wrap_as_background_process("launch_scheduled_tasks")
     async def _launch_scheduled_tasks(self) -> None:
         """Retrieve and launch scheduled tasks that should be running at that time."""
-        for task in await self.get_tasks(statuses=[TaskStatus.ACTIVE]):
-            await self._launch_task(task)
-        for task in await self.get_tasks(
-            statuses=[TaskStatus.SCHEDULED], max_timestamp=self._clock.time_msec()
-        ):
-            await self._launch_task(task)
+        # Don't bother trying to launch new tasks if we're already at capacity.
+        if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
+            return
+
+        if self._launching_new_tasks:
+            return
 
-        running_tasks_gauge.set(len(self._running_tasks))
+        self._launching_new_tasks = True
 
+        try:
+            for task in await self.get_tasks(
+                statuses=[TaskStatus.ACTIVE], limit=self.MAX_CONCURRENT_RUNNING_TASKS
+            ):
+                await self._launch_task(task)
+            for task in await self.get_tasks(
+                statuses=[TaskStatus.SCHEDULED],
+                max_timestamp=self._clock.time_msec(),
+                limit=self.MAX_CONCURRENT_RUNNING_TASKS,
+            ):
+                await self._launch_task(task)
+
+        finally:
+            self._launching_new_tasks = False
+
+    @wrap_as_background_process("clean_scheduled_tasks")
     async def _clean_scheduled_tasks(self) -> None:
         """Clean old complete or failed jobs to avoid clutter the DB."""
+        now = self._clock.time_msec()
         for task in await self._store.get_scheduled_tasks(
-            statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE]
+            statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE],
+            max_timestamp=now - TaskScheduler.KEEP_TASKS_FOR_MS,
         ):
             # FAILED and COMPLETE tasks should never be running
             assert task.id not in self._running_tasks
-            if (
-                self._clock.time_msec()
-                > task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS
-            ):
-                await self._store.delete_scheduled_task(task.id)
+            await self._store.delete_scheduled_task(task.id)
 
     async def _launch_task(self, task: ScheduledTask) -> None:
         """Launch a scheduled task now.
@@ -315,30 +347,37 @@ class TaskScheduler:
         """
         assert self._run_background_tasks
 
-        assert task.action in self._actions
+        if task.action not in self._actions:
+            raise Exception(
+                f"No function associated with action {task.action} of the scheduled task {task.id}"
+            )
         function = self._actions[task.action]
 
         async def wrapper() -> None:
-            try:
-                (status, result, error) = await function(task)
-            except Exception:
-                f = Failure()
-                logger.error(
-                    f"scheduled task {task.id} failed",
-                    exc_info=(f.type, f.value, f.getTracebackObject()),
+            with nested_logging_context(task.id):
+                try:
+                    (status, result, error) = await function(task)
+                except Exception:
+                    f = Failure()
+                    logger.error(
+                        f"scheduled task {task.id} failed",
+                        exc_info=(f.type, f.value, f.getTracebackObject()),
+                    )
+                    status = TaskStatus.FAILED
+                    result = None
+                    error = f.getErrorMessage()
+
+                await self._store.update_scheduled_task(
+                    task.id,
+                    self._clock.time_msec(),
+                    status=status,
+                    result=result,
+                    error=error,
                 )
-                status = TaskStatus.FAILED
-                result = None
-                error = f.getErrorMessage()
-
-            await self._store.update_scheduled_task(
-                task.id,
-                self._clock.time_msec(),
-                status=status,
-                result=result,
-                error=error,
-            )
-            self._running_tasks.remove(task.id)
+                self._running_tasks.remove(task.id)
+
+            # Try launch a new task since we've finished with this one.
+            self._clock.call_later(1, self._launch_scheduled_tasks)
 
         if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
             return
@@ -356,5 +395,4 @@ class TaskScheduler:
 
         self._running_tasks.add(task.id)
         await self.update_task(task.id, status=TaskStatus.ACTIVE)
-        description = f"{task.id}-{task.action}"
-        run_as_background_process(description, wrapper)
+        run_as_background_process(f"task-{task.action}", wrapper)