summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11306.feature1
-rw-r--r--docs/modules/background_update_controller_callbacks.md71
-rw-r--r--docs/modules/writing_a_module.md12
-rwxr-xr-xsetup.py4
-rw-r--r--synapse/module_api/__init__.py54
-rw-r--r--synapse/storage/background_updates.py192
-rw-r--r--tests/push/test_email.py9
-rw-r--r--tests/rest/admin/test_background_updates.py2
-rw-r--r--tests/storage/test_background_update.py104
-rw-r--r--tests/storage/test_event_chain.py4
-rw-r--r--tests/storage/test_user_directory.py5
-rw-r--r--tests/unittest.py10
12 files changed, 407 insertions, 61 deletions
diff --git a/changelog.d/11306.feature b/changelog.d/11306.feature
new file mode 100644
index 0000000000..aba3292015
--- /dev/null
+++ b/changelog.d/11306.feature
@@ -0,0 +1 @@
+Add plugin support for controlling database background updates.
diff --git a/docs/modules/background_update_controller_callbacks.md b/docs/modules/background_update_controller_callbacks.md
new file mode 100644
index 0000000000..b3e7c259f4
--- /dev/null
+++ b/docs/modules/background_update_controller_callbacks.md
@@ -0,0 +1,71 @@
+# Background update controller callbacks
+
+Background update controller callbacks allow module developers to control (e.g. rate-limit)
+how database background updates are run. A database background update is an operation
+Synapse runs on its database in the background after it starts. It's usually used to run
+database operations that would take too long if they were run at the same time as schema
+updates (which are run on startup) and delay Synapse's startup too much: populating a
+table with a big amount of data, adding an index on a big table, deleting superfluous data,
+etc.
+
+Background update controller callbacks can be registered using the module API's
+`register_background_update_controller_callbacks` method. Only the first module (in order
+of appearance in Synapse's configuration file) calling this method can register background
+update controller callbacks, subsequent calls are ignored.
+
+The available background update controller callbacks are:
+
+### `on_update`
+
+_First introduced in Synapse v1.49.0_
+
+```python
+def on_update(update_name: str, database_name: str, one_shot: bool) -> AsyncContextManager[int]
+```
+
+Called when about to do an iteration of a background update. The module is given the name
+of the update, the name of the database, and a flag to indicate whether the background
+update will happen in one go and may take a long time (e.g. creating indices). If this last
+argument is set to `False`, the update will be run in batches.
+
+The module must return an async context manager. It will be entered before Synapse runs a 
+background update; this should return the desired duration of the iteration, in
+milliseconds.
+
+The context manager will be exited when the iteration completes. Note that the duration
+returned by the context manager is a target, and an iteration may take substantially longer
+or shorter. If the `one_shot` flag is set to `True`, the duration returned is ignored.
+
+__Note__: Unlike most module callbacks in Synapse, this one is _synchronous_. This is
+because asynchronous operations are expected to be run by the async context manager.
+
+This callback is required when registering any other background update controller callback.
+
+### `default_batch_size`
+
+_First introduced in Synapse v1.49.0_
+
+```python
+async def default_batch_size(update_name: str, database_name: str) -> int
+```
+
+Called before the first iteration of a background update, with the name of the update and
+of the database. The module must return the number of elements to process in this first
+iteration.
+
+If this callback is not defined, Synapse will use a default value of 100.
+
+### `min_batch_size`
+
+_First introduced in Synapse v1.49.0_
+
+```python
+async def min_batch_size(update_name: str, database_name: str) -> int
+```
+
+Called before running a new batch for a background update, with the name of the update and
+of the database. The module must return an integer representing the minimum number of
+elements to process in this iteration. This number must be at least 1, and is used to
+ensure that progress is always made.
+
+If this callback is not defined, Synapse will use a default value of 100.
diff --git a/docs/modules/writing_a_module.md b/docs/modules/writing_a_module.md
index 7764e06692..e7c0ffad58 100644
--- a/docs/modules/writing_a_module.md
+++ b/docs/modules/writing_a_module.md
@@ -71,15 +71,15 @@ Modules **must** register their web resources in their `__init__` method.
 ## Registering a callback
 
 Modules can use Synapse's module API to register callbacks. Callbacks are functions that
-Synapse will call when performing specific actions. Callbacks must be asynchronous, and
-are split in categories. A single module may implement callbacks from multiple categories,
-and is under no obligation to implement all callbacks from the categories it registers
-callbacks for.
+Synapse will call when performing specific actions. Callbacks must be asynchronous (unless
+specified otherwise), and are split in categories. A single module may implement callbacks
+from multiple categories, and is under no obligation to implement all callbacks from the
+categories it registers callbacks for.
 
 Modules can register callbacks using one of the module API's `register_[...]_callbacks`
 methods. The callback functions are passed to these methods as keyword arguments, with
-the callback name as the argument name and the function as its value. This is demonstrated
-in the example below. A `register_[...]_callbacks` method exists for each category.
+the callback name as the argument name and the function as its value. A
+`register_[...]_callbacks` method exists for each category.
 
 Callbacks for each category can be found on their respective page of the
 [Synapse documentation website](https://matrix-org.github.io/synapse).
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 0ce8beb004..ad99b3bd2c 100755
--- a/setup.py
+++ b/setup.py
@@ -119,7 +119,9 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [
 # Tests assume that all optional dependencies are installed.
 #
 # parameterized_class decorator was introduced in parameterized 0.7.0
-CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0"]
+#
+# We use `mock` library as that backports `AsyncMock` to Python 3.6
+CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0", "mock>=4.0.0"]
 
 CONDITIONAL_REQUIREMENTS["dev"] = (
     CONDITIONAL_REQUIREMENTS["lint"]
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 19e570ede2..a8154168be 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -82,10 +82,19 @@ from synapse.http.server import (
 )
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import (
+    defer_to_thread,
+    make_deferred_yieldable,
+    run_in_background,
+)
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.client.login import LoginResponse
 from synapse.storage import DataStore
+from synapse.storage.background_updates import (
+    DEFAULT_BATCH_SIZE_CALLBACK,
+    MIN_BATCH_SIZE_CALLBACK,
+    ON_UPDATE_CALLBACK,
+)
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
@@ -311,6 +320,24 @@ class ModuleApi:
             auth_checkers=auth_checkers,
         )
 
+    def register_background_update_controller_callbacks(
+        self,
+        on_update: ON_UPDATE_CALLBACK,
+        default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+        min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None,
+    ) -> None:
+        """Registers background update controller callbacks.
+
+        Added in Synapse v1.49.0.
+        """
+
+        for db in self._hs.get_datastores().databases:
+            db.updates.register_update_controller_callbacks(
+                on_update=on_update,
+                default_batch_size=default_batch_size,
+                min_batch_size=min_batch_size,
+            )
+
     def register_web_resource(self, path: str, resource: Resource) -> None:
         """Registers a web resource to be served at the given path.
 
@@ -995,6 +1022,11 @@ class ModuleApi:
                 f,
             )
 
+    async def sleep(self, seconds: float) -> None:
+        """Sleeps for the given number of seconds."""
+
+        await self._clock.sleep(seconds)
+
     async def send_mail(
         self,
         recipient: str,
@@ -1149,6 +1181,26 @@ class ModuleApi:
 
         return {key: state_events[event_id] for key, event_id in state_ids.items()}
 
+    async def defer_to_thread(
+        self,
+        f: Callable[..., T],
+        *args: Any,
+        **kwargs: Any,
+    ) -> T:
+        """Runs the given function in a separate thread from Synapse's thread pool.
+
+        Added in Synapse v1.49.0.
+
+        Args:
+            f: The function to run.
+            args: The function's arguments.
+            kwargs: The function's keyword arguments.
+
+        Returns:
+            The return value of the function once ran in a thread.
+        """
+        return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs)
+
 
 class PublicRoomListManager:
     """Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index bc8364400d..d64910aded 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,12 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
+from typing import (
+    TYPE_CHECKING,
+    AsyncContextManager,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Optional,
+)
+
+import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.types import Connection
 from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import Clock, json_encoder
 
 from . import engines
 
@@ -28,6 +38,45 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
+DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _BackgroundUpdateHandler:
+    """A handler for a given background update.
+
+    Attributes:
+        callback: The function to call to make progress on the background
+            update.
+        oneshot: Wether the update is likely to happen all in one go, ignoring
+            the supplied target duration, e.g. index creation. This is used by
+            the update controller to help correctly schedule the update.
+    """
+
+    callback: Callable[[JsonDict, int], Awaitable[int]]
+    oneshot: bool = False
+
+
+class _BackgroundUpdateContextManager:
+    BACKGROUND_UPDATE_INTERVAL_MS = 1000
+    BACKGROUND_UPDATE_DURATION_MS = 100
+
+    def __init__(self, sleep: bool, clock: Clock):
+        self._sleep = sleep
+        self._clock = clock
+
+    async def __aenter__(self) -> int:
+        if self._sleep:
+            await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
+
+        return self.BACKGROUND_UPDATE_DURATION_MS
+
+    async def __aexit__(self, *exc) -> None:
+        pass
+
+
 class BackgroundUpdatePerformance:
     """Tracks the how long a background update is taking to update its items"""
 
@@ -84,20 +133,22 @@ class BackgroundUpdater:
 
     MINIMUM_BACKGROUND_BATCH_SIZE = 1
     DEFAULT_BACKGROUND_BATCH_SIZE = 100
-    BACKGROUND_UPDATE_INTERVAL_MS = 1000
-    BACKGROUND_UPDATE_DURATION_MS = 100
 
     def __init__(self, hs: "HomeServer", database: "DatabasePool"):
         self._clock = hs.get_clock()
         self.db_pool = database
 
+        self._database_name = database.name()
+
         # if a background update is currently running, its name.
         self._current_background_update: Optional[str] = None
 
+        self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
+        self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
+        self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
+
         self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
-        self._background_update_handlers: Dict[
-            str, Callable[[JsonDict, int], Awaitable[int]]
-        ] = {}
+        self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
         self._all_done = False
 
         # Whether we're currently running updates
@@ -107,6 +158,83 @@ class BackgroundUpdater:
         # enable/disable background updates via the admin API.
         self.enabled = True
 
+    def register_update_controller_callbacks(
+        self,
+        on_update: ON_UPDATE_CALLBACK,
+        default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+        min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+    ) -> None:
+        """Register callbacks from a module for each hook."""
+        if self._on_update_callback is not None:
+            logger.warning(
+                "More than one module tried to register callbacks for controlling"
+                " background updates. Only the callbacks registered by the first module"
+                " (in order of appearance in Synapse's configuration file) that tried to"
+                " do so will be called."
+            )
+
+            return
+
+        self._on_update_callback = on_update
+
+        if default_batch_size is not None:
+            self._default_batch_size_callback = default_batch_size
+
+        if min_batch_size is not None:
+            self._min_batch_size_callback = min_batch_size
+
+    def _get_context_manager_for_update(
+        self,
+        sleep: bool,
+        update_name: str,
+        database_name: str,
+        oneshot: bool,
+    ) -> AsyncContextManager[int]:
+        """Get a context manager to run a background update with.
+
+        If a module has registered a `update_handler` callback, use the context manager
+        it returns.
+
+        Otherwise, returns a context manager that will return a default value, optionally
+        sleeping if needed.
+
+        Args:
+            sleep: Whether we can sleep between updates.
+            update_name: The name of the update.
+            database_name: The name of the database the update is being run on.
+            oneshot: Whether the update will complete all in one go, e.g. index creation.
+                In such cases the returned target duration is ignored.
+
+        Returns:
+            The target duration in milliseconds that the background update should run for.
+
+            Note: this is a *target*, and an iteration may take substantially longer or
+            shorter.
+        """
+        if self._on_update_callback is not None:
+            return self._on_update_callback(update_name, database_name, oneshot)
+
+        return _BackgroundUpdateContextManager(sleep, self._clock)
+
+    async def _default_batch_size(self, update_name: str, database_name: str) -> int:
+        """The batch size to use for the first iteration of a new background
+        update.
+        """
+        if self._default_batch_size_callback is not None:
+            return await self._default_batch_size_callback(update_name, database_name)
+
+        return self.DEFAULT_BACKGROUND_BATCH_SIZE
+
+    async def _min_batch_size(self, update_name: str, database_name: str) -> int:
+        """A lower bound on the batch size of a new background update.
+
+        Used to ensure that progress is always made. Must be greater than 0.
+        """
+        if self._min_batch_size_callback is not None:
+            return await self._min_batch_size_callback(update_name, database_name)
+
+        return self.MINIMUM_BACKGROUND_BATCH_SIZE
+
     def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
         """Returns the current background update, if any."""
 
@@ -135,13 +263,8 @@ class BackgroundUpdater:
         try:
             logger.info("Starting background schema updates")
             while self.enabled:
-                if sleep:
-                    await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
-
                 try:
-                    result = await self.do_next_background_update(
-                        self.BACKGROUND_UPDATE_DURATION_MS
-                    )
+                    result = await self.do_next_background_update(sleep)
                 except Exception:
                     logger.exception("Error doing update")
                 else:
@@ -203,13 +326,15 @@ class BackgroundUpdater:
 
         return not update_exists
 
-    async def do_next_background_update(self, desired_duration_ms: float) -> bool:
+    async def do_next_background_update(self, sleep: bool = True) -> bool:
         """Does some amount of work on the next queued background update
 
         Returns once some amount of work is done.
 
         Args:
-            desired_duration_ms: How long we want to spend updating.
+            sleep: Whether to limit how quickly we run background updates or
+                not.
+
         Returns:
             True if we have finished running all the background updates, otherwise False
         """
@@ -252,7 +377,19 @@ class BackgroundUpdater:
 
             self._current_background_update = upd["update_name"]
 
-        await self._do_background_update(desired_duration_ms)
+        # We have a background update to run, otherwise we would have returned
+        # early.
+        assert self._current_background_update is not None
+        update_info = self._background_update_handlers[self._current_background_update]
+
+        async with self._get_context_manager_for_update(
+            sleep=sleep,
+            update_name=self._current_background_update,
+            database_name=self._database_name,
+            oneshot=update_info.oneshot,
+        ) as desired_duration_ms:
+            await self._do_background_update(desired_duration_ms)
+
         return False
 
     async def _do_background_update(self, desired_duration_ms: float) -> int:
@@ -260,7 +397,7 @@ class BackgroundUpdater:
         update_name = self._current_background_update
         logger.info("Starting update batch on background update '%s'", update_name)
 
-        update_handler = self._background_update_handlers[update_name]
+        update_handler = self._background_update_handlers[update_name].callback
 
         performance = self._background_update_performance.get(update_name)
 
@@ -273,9 +410,14 @@ class BackgroundUpdater:
         if items_per_ms is not None:
             batch_size = int(desired_duration_ms * items_per_ms)
             # Clamp the batch size so that we always make progress
-            batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
+            batch_size = max(
+                batch_size,
+                await self._min_batch_size(update_name, self._database_name),
+            )
         else:
-            batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
+            batch_size = await self._default_batch_size(
+                update_name, self._database_name
+            )
 
         progress_json = await self.db_pool.simple_select_one_onecol(
             "background_updates",
@@ -294,6 +436,8 @@ class BackgroundUpdater:
 
         duration_ms = time_stop - time_start
 
+        performance.update(items_updated, duration_ms)
+
         logger.info(
             "Running background update %r. Processed %r items in %rms."
             " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
@@ -306,8 +450,6 @@ class BackgroundUpdater:
             batch_size,
         )
 
-        performance.update(items_updated, duration_ms)
-
         return len(self._background_update_performance)
 
     def register_background_update_handler(
@@ -331,7 +473,9 @@ class BackgroundUpdater:
             update_name: The name of the update that this code handles.
             update_handler: The function that does the update.
         """
-        self._background_update_handlers[update_name] = update_handler
+        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+            update_handler
+        )
 
     def register_noop_background_update(self, update_name: str) -> None:
         """Register a noop handler for a background update.
@@ -453,7 +597,9 @@ class BackgroundUpdater:
             await self._end_background_update(update_name)
             return 1
 
-        self.register_background_update_handler(update_name, updater)
+        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+            updater, oneshot=True
+        )
 
     async def _end_background_update(self, update_name: str) -> None:
         """Removes a completed background update task from the queue.
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 90f800e564..f8cba7b645 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -128,6 +128,7 @@ class EmailPusherTests(HomeserverTestCase):
         )
 
         self.auth_handler = hs.get_auth_handler()
+        self.store = hs.get_datastore()
 
     def test_need_validated_email(self):
         """Test that we can only add an email pusher if the user has validated
@@ -408,13 +409,7 @@ class EmailPusherTests(HomeserverTestCase):
         self.hs.get_datastore().db_pool.updates._all_done = False
 
         # Now let's actually drive the updates to completion
-        while not self.get_success(
-            self.hs.get_datastore().db_pool.updates.has_completed_background_updates()
-        ):
-            self.get_success(
-                self.hs.get_datastore().db_pool.updates.do_next_background_update(100),
-                by=0.1,
-            )
+        self.wait_for_background_updates()
 
         # Check that all pushers with unlinked addresses were deleted
         pushers = self.get_success(
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index cd5c60b65c..62f242baf6 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -135,7 +135,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
         self._register_bg_update()
 
         self.store.db_pool.updates.start_doing_background_updates()
-        self.reactor.pump([1.0, 1.0])
+        self.reactor.pump([1.0, 1.0, 1.0])
 
         channel = self.make_request(
             "GET",
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index a5f5ebad41..216d816d56 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,8 +1,11 @@
-from unittest.mock import Mock
+from mock import Mock
+
+from twisted.internet.defer import Deferred, ensureDeferred
 
 from synapse.storage.background_updates import BackgroundUpdater
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
 class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
@@ -20,10 +23,10 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
 
     def test_do_background_update(self):
         # the time we claim it takes to update one item when running the update
-        duration_ms = 4200
+        duration_ms = 10
 
         # the target runtime for each bg update
-        target_background_update_duration_ms = 5000000
+        target_background_update_duration_ms = 100
 
         store = self.hs.get_datastore()
         self.get_success(
@@ -48,10 +51,8 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
         self.update_handler.side_effect = update
         self.update_handler.reset_mock()
         res = self.get_success(
-            self.updates.do_next_background_update(
-                target_background_update_duration_ms
-            ),
-            by=0.1,
+            self.updates.do_next_background_update(False),
+            by=0.01,
         )
         self.assertFalse(res)
 
@@ -74,16 +75,93 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
 
         self.update_handler.side_effect = update
         self.update_handler.reset_mock()
-        result = self.get_success(
-            self.updates.do_next_background_update(target_background_update_duration_ms)
-        )
+        result = self.get_success(self.updates.do_next_background_update(False))
         self.assertFalse(result)
         self.update_handler.assert_called_once()
 
         # third step: we don't expect to be called any more
         self.update_handler.reset_mock()
-        result = self.get_success(
-            self.updates.do_next_background_update(target_background_update_duration_ms)
-        )
+        result = self.get_success(self.updates.do_next_background_update(False))
         self.assertTrue(result)
         self.assertFalse(self.update_handler.called)
+
+
+class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, homeserver):
+        self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates
+        # the base test class should have run the real bg updates for us
+        self.assertTrue(
+            self.get_success(self.updates.has_completed_background_updates())
+        )
+
+        self.update_deferred = Deferred()
+        self.update_handler = Mock(return_value=self.update_deferred)
+        self.updates.register_background_update_handler(
+            "test_update", self.update_handler
+        )
+
+        # Mock out the AsyncContextManager
+        self._update_ctx_manager = Mock(spec=["__aenter__", "__aexit__"])
+        self._update_ctx_manager.__aenter__ = Mock(
+            return_value=make_awaitable(None),
+        )
+        self._update_ctx_manager.__aexit__ = Mock(return_value=make_awaitable(None))
+
+        # Mock out the `update_handler` callback
+        self._on_update = Mock(return_value=self._update_ctx_manager)
+
+        # Define a default batch size value that's not the same as the internal default
+        # value (100).
+        self._default_batch_size = 500
+
+        # Register the callbacks with more mocks
+        self.hs.get_module_api().register_background_update_controller_callbacks(
+            on_update=self._on_update,
+            min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)),
+            default_batch_size=Mock(
+                return_value=make_awaitable(self._default_batch_size),
+            ),
+        )
+
+    def test_controller(self):
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "background_updates",
+                values={"update_name": "test_update", "progress_json": "{}"},
+            )
+        )
+
+        # Set the return value for the context manager.
+        enter_defer = Deferred()
+        self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
+
+        # Start the background update.
+        do_update_d = ensureDeferred(self.updates.do_next_background_update(True))
+
+        self.pump()
+
+        # `run_update` should have been called, but the update handler won't be
+        # called until the `enter_defer` (returned by `__aenter__`) is resolved.
+        self._on_update.assert_called_once_with(
+            "test_update",
+            "master",
+            False,
+        )
+        self.assertFalse(do_update_d.called)
+        self.assertFalse(self.update_deferred.called)
+
+        # Resolving the `enter_defer` should call the update handler, which then
+        # blocks.
+        enter_defer.callback(100)
+        self.pump()
+        self.update_handler.assert_called_once_with({}, self._default_batch_size)
+        self.assertFalse(self.update_deferred.called)
+        self._update_ctx_manager.__aexit__.assert_not_called()
+
+        # Resolving the update handler deferred should cause the
+        # `do_next_background_update` to finish and return
+        self.update_deferred.callback(100)
+        self.pump()
+        self._update_ctx_manager.__aexit__.assert_called()
+        self.get_success(do_update_d)
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index b31c5eb5ec..7b7f6c349e 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -664,7 +664,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
         ):
             iterations += 1
             self.get_success(
-                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(False), by=0.1
             )
 
         # Ensure that we did actually take multiple iterations to process the
@@ -723,7 +723,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
         ):
             iterations += 1
             self.get_success(
-                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(False), by=0.1
             )
 
         # Ensure that we did actually take multiple iterations to process the
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 37cf7bb232..7f5b28aed8 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -23,6 +23,7 @@ from synapse.rest import admin
 from synapse.rest.client import login, register, room
 from synapse.server import HomeServer
 from synapse.storage import DataStore
+from synapse.storage.background_updates import _BackgroundUpdateHandler
 from synapse.storage.roommember import ProfileInfo
 from synapse.util import Clock
 
@@ -391,7 +392,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
 
         with mock.patch.dict(
             self.store.db_pool.updates._background_update_handlers,
-            populate_user_directory_process_users=mocked_process_users,
+            populate_user_directory_process_users=_BackgroundUpdateHandler(
+                mocked_process_users,
+            ),
         ):
             self._purge_and_rebuild_user_dir()
 
diff --git a/tests/unittest.py b/tests/unittest.py
index 165aafc574..eea0903f05 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -331,17 +331,16 @@ class HomeserverTestCase(TestCase):
             time.sleep(0.01)
 
     def wait_for_background_updates(self) -> None:
-        """
-        Block until all background database updates have completed.
+        """Block until all background database updates have completed.
 
-        Note that callers must ensure that's a store property created on the
+        Note that callers must ensure there's a store property created on the
         testcase.
         """
         while not self.get_success(
             self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(False), by=0.1
             )
 
     def make_homeserver(self, reactor, clock):
@@ -500,8 +499,7 @@ class HomeserverTestCase(TestCase):
 
         async def run_bg_updates():
             with LoggingContext("run_bg_updates"):
-                while not await stor.db_pool.updates.has_completed_background_updates():
-                    await stor.db_pool.updates.do_next_background_update(1)
+                self.get_success(stor.db_pool.updates.run_background_updates(False))
 
         hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
         stor = hs.get_datastore()