summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/background_updates.py192
1 files changed, 169 insertions, 23 deletions
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index bc8364400d..d64910aded 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,12 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
+from typing import (
+    TYPE_CHECKING,
+    AsyncContextManager,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Optional,
+)
+
+import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.types import Connection
 from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import Clock, json_encoder
 
 from . import engines
 
@@ -28,6 +38,45 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
+DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _BackgroundUpdateHandler:
+    """A handler for a given background update.
+
+    Attributes:
+        callback: The function to call to make progress on the background
+            update.
+        oneshot: Wether the update is likely to happen all in one go, ignoring
+            the supplied target duration, e.g. index creation. This is used by
+            the update controller to help correctly schedule the update.
+    """
+
+    callback: Callable[[JsonDict, int], Awaitable[int]]
+    oneshot: bool = False
+
+
+class _BackgroundUpdateContextManager:
+    BACKGROUND_UPDATE_INTERVAL_MS = 1000
+    BACKGROUND_UPDATE_DURATION_MS = 100
+
+    def __init__(self, sleep: bool, clock: Clock):
+        self._sleep = sleep
+        self._clock = clock
+
+    async def __aenter__(self) -> int:
+        if self._sleep:
+            await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
+
+        return self.BACKGROUND_UPDATE_DURATION_MS
+
+    async def __aexit__(self, *exc) -> None:
+        pass
+
+
 class BackgroundUpdatePerformance:
     """Tracks the how long a background update is taking to update its items"""
 
@@ -84,20 +133,22 @@ class BackgroundUpdater:
 
     MINIMUM_BACKGROUND_BATCH_SIZE = 1
     DEFAULT_BACKGROUND_BATCH_SIZE = 100
-    BACKGROUND_UPDATE_INTERVAL_MS = 1000
-    BACKGROUND_UPDATE_DURATION_MS = 100
 
     def __init__(self, hs: "HomeServer", database: "DatabasePool"):
         self._clock = hs.get_clock()
         self.db_pool = database
 
+        self._database_name = database.name()
+
         # if a background update is currently running, its name.
         self._current_background_update: Optional[str] = None
 
+        self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
+        self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
+        self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
+
         self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
-        self._background_update_handlers: Dict[
-            str, Callable[[JsonDict, int], Awaitable[int]]
-        ] = {}
+        self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
         self._all_done = False
 
         # Whether we're currently running updates
@@ -107,6 +158,83 @@ class BackgroundUpdater:
         # enable/disable background updates via the admin API.
         self.enabled = True
 
+    def register_update_controller_callbacks(
+        self,
+        on_update: ON_UPDATE_CALLBACK,
+        default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+        min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+    ) -> None:
+        """Register callbacks from a module for each hook."""
+        if self._on_update_callback is not None:
+            logger.warning(
+                "More than one module tried to register callbacks for controlling"
+                " background updates. Only the callbacks registered by the first module"
+                " (in order of appearance in Synapse's configuration file) that tried to"
+                " do so will be called."
+            )
+
+            return
+
+        self._on_update_callback = on_update
+
+        if default_batch_size is not None:
+            self._default_batch_size_callback = default_batch_size
+
+        if min_batch_size is not None:
+            self._min_batch_size_callback = min_batch_size
+
+    def _get_context_manager_for_update(
+        self,
+        sleep: bool,
+        update_name: str,
+        database_name: str,
+        oneshot: bool,
+    ) -> AsyncContextManager[int]:
+        """Get a context manager to run a background update with.
+
+        If a module has registered a `update_handler` callback, use the context manager
+        it returns.
+
+        Otherwise, returns a context manager that will return a default value, optionally
+        sleeping if needed.
+
+        Args:
+            sleep: Whether we can sleep between updates.
+            update_name: The name of the update.
+            database_name: The name of the database the update is being run on.
+            oneshot: Whether the update will complete all in one go, e.g. index creation.
+                In such cases the returned target duration is ignored.
+
+        Returns:
+            The target duration in milliseconds that the background update should run for.
+
+            Note: this is a *target*, and an iteration may take substantially longer or
+            shorter.
+        """
+        if self._on_update_callback is not None:
+            return self._on_update_callback(update_name, database_name, oneshot)
+
+        return _BackgroundUpdateContextManager(sleep, self._clock)
+
+    async def _default_batch_size(self, update_name: str, database_name: str) -> int:
+        """The batch size to use for the first iteration of a new background
+        update.
+        """
+        if self._default_batch_size_callback is not None:
+            return await self._default_batch_size_callback(update_name, database_name)
+
+        return self.DEFAULT_BACKGROUND_BATCH_SIZE
+
+    async def _min_batch_size(self, update_name: str, database_name: str) -> int:
+        """A lower bound on the batch size of a new background update.
+
+        Used to ensure that progress is always made. Must be greater than 0.
+        """
+        if self._min_batch_size_callback is not None:
+            return await self._min_batch_size_callback(update_name, database_name)
+
+        return self.MINIMUM_BACKGROUND_BATCH_SIZE
+
     def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
         """Returns the current background update, if any."""
 
@@ -135,13 +263,8 @@ class BackgroundUpdater:
         try:
             logger.info("Starting background schema updates")
             while self.enabled:
-                if sleep:
-                    await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
-
                 try:
-                    result = await self.do_next_background_update(
-                        self.BACKGROUND_UPDATE_DURATION_MS
-                    )
+                    result = await self.do_next_background_update(sleep)
                 except Exception:
                     logger.exception("Error doing update")
                 else:
@@ -203,13 +326,15 @@ class BackgroundUpdater:
 
         return not update_exists
 
-    async def do_next_background_update(self, desired_duration_ms: float) -> bool:
+    async def do_next_background_update(self, sleep: bool = True) -> bool:
         """Does some amount of work on the next queued background update
 
         Returns once some amount of work is done.
 
         Args:
-            desired_duration_ms: How long we want to spend updating.
+            sleep: Whether to limit how quickly we run background updates or
+                not.
+
         Returns:
             True if we have finished running all the background updates, otherwise False
         """
@@ -252,7 +377,19 @@ class BackgroundUpdater:
 
             self._current_background_update = upd["update_name"]
 
-        await self._do_background_update(desired_duration_ms)
+        # We have a background update to run, otherwise we would have returned
+        # early.
+        assert self._current_background_update is not None
+        update_info = self._background_update_handlers[self._current_background_update]
+
+        async with self._get_context_manager_for_update(
+            sleep=sleep,
+            update_name=self._current_background_update,
+            database_name=self._database_name,
+            oneshot=update_info.oneshot,
+        ) as desired_duration_ms:
+            await self._do_background_update(desired_duration_ms)
+
         return False
 
     async def _do_background_update(self, desired_duration_ms: float) -> int:
@@ -260,7 +397,7 @@ class BackgroundUpdater:
         update_name = self._current_background_update
         logger.info("Starting update batch on background update '%s'", update_name)
 
-        update_handler = self._background_update_handlers[update_name]
+        update_handler = self._background_update_handlers[update_name].callback
 
         performance = self._background_update_performance.get(update_name)
 
@@ -273,9 +410,14 @@ class BackgroundUpdater:
         if items_per_ms is not None:
             batch_size = int(desired_duration_ms * items_per_ms)
             # Clamp the batch size so that we always make progress
-            batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
+            batch_size = max(
+                batch_size,
+                await self._min_batch_size(update_name, self._database_name),
+            )
         else:
-            batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
+            batch_size = await self._default_batch_size(
+                update_name, self._database_name
+            )
 
         progress_json = await self.db_pool.simple_select_one_onecol(
             "background_updates",
@@ -294,6 +436,8 @@ class BackgroundUpdater:
 
         duration_ms = time_stop - time_start
 
+        performance.update(items_updated, duration_ms)
+
         logger.info(
             "Running background update %r. Processed %r items in %rms."
             " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
@@ -306,8 +450,6 @@ class BackgroundUpdater:
             batch_size,
         )
 
-        performance.update(items_updated, duration_ms)
-
         return len(self._background_update_performance)
 
     def register_background_update_handler(
@@ -331,7 +473,9 @@ class BackgroundUpdater:
             update_name: The name of the update that this code handles.
             update_handler: The function that does the update.
         """
-        self._background_update_handlers[update_name] = update_handler
+        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+            update_handler
+        )
 
     def register_noop_background_update(self, update_name: str) -> None:
         """Register a noop handler for a background update.
@@ -453,7 +597,9 @@ class BackgroundUpdater:
             await self._end_background_update(update_name)
             return 1
 
-        self.register_background_update_handler(update_name, updater)
+        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+            updater, oneshot=True
+        )
 
     async def _end_background_update(self, update_name: str) -> None:
         """Removes a completed background update task from the queue.