diff --git a/changelog.d/11306.feature b/changelog.d/11306.feature
new file mode 100644
index 0000000000..aba3292015
--- /dev/null
+++ b/changelog.d/11306.feature
@@ -0,0 +1 @@
+Add plugin support for controlling database background updates.
diff --git a/docs/modules/background_update_controller_callbacks.md b/docs/modules/background_update_controller_callbacks.md
new file mode 100644
index 0000000000..b3e7c259f4
--- /dev/null
+++ b/docs/modules/background_update_controller_callbacks.md
@@ -0,0 +1,71 @@
+# Background update controller callbacks
+
+Background update controller callbacks allow module developers to control (e.g. rate-limit)
+how database background updates are run. A database background update is an operation
+Synapse runs on its database in the background after it starts. It's usually used to run
+database operations that would take too long if they were run at the same time as schema
+updates (which are run on startup) and delay Synapse's startup too much: populating a
+table with a big amount of data, adding an index on a big table, deleting superfluous data,
+etc.
+
+Background update controller callbacks can be registered using the module API's
+`register_background_update_controller_callbacks` method. Only the first module (in order
+of appearance in Synapse's configuration file) calling this method can register background
+update controller callbacks, subsequent calls are ignored.
+
+The available background update controller callbacks are:
+
+### `on_update`
+
+_First introduced in Synapse v1.49.0_
+
+```python
+def on_update(update_name: str, database_name: str, one_shot: bool) -> AsyncContextManager[int]
+```
+
+Called when about to do an iteration of a background update. The module is given the name
+of the update, the name of the database, and a flag to indicate whether the background
+update will happen in one go and may take a long time (e.g. creating indices). If this last
+argument is set to `False`, the update will be run in batches.
+
+The module must return an async context manager. It will be entered before Synapse runs a
+background update; this should return the desired duration of the iteration, in
+milliseconds.
+
+The context manager will be exited when the iteration completes. Note that the duration
+returned by the context manager is a target, and an iteration may take substantially longer
+or shorter. If the `one_shot` flag is set to `True`, the duration returned is ignored.
+
+__Note__: Unlike most module callbacks in Synapse, this one is _synchronous_. This is
+because asynchronous operations are expected to be run by the async context manager.
+
+This callback is required when registering any other background update controller callback.
+
+### `default_batch_size`
+
+_First introduced in Synapse v1.49.0_
+
+```python
+async def default_batch_size(update_name: str, database_name: str) -> int
+```
+
+Called before the first iteration of a background update, with the name of the update and
+of the database. The module must return the number of elements to process in this first
+iteration.
+
+If this callback is not defined, Synapse will use a default value of 100.
+
+### `min_batch_size`
+
+_First introduced in Synapse v1.49.0_
+
+```python
+async def min_batch_size(update_name: str, database_name: str) -> int
+```
+
+Called before running a new batch for a background update, with the name of the update and
+of the database. The module must return an integer representing the minimum number of
+elements to process in this iteration. This number must be at least 1, and is used to
+ensure that progress is always made.
+
+If this callback is not defined, Synapse will use a default value of 100.
diff --git a/docs/modules/writing_a_module.md b/docs/modules/writing_a_module.md
index 7764e06692..e7c0ffad58 100644
--- a/docs/modules/writing_a_module.md
+++ b/docs/modules/writing_a_module.md
@@ -71,15 +71,15 @@ Modules **must** register their web resources in their `__init__` method.
## Registering a callback
Modules can use Synapse's module API to register callbacks. Callbacks are functions that
-Synapse will call when performing specific actions. Callbacks must be asynchronous, and
-are split in categories. A single module may implement callbacks from multiple categories,
-and is under no obligation to implement all callbacks from the categories it registers
-callbacks for.
+Synapse will call when performing specific actions. Callbacks must be asynchronous (unless
+specified otherwise), and are split in categories. A single module may implement callbacks
+from multiple categories, and is under no obligation to implement all callbacks from the
+categories it registers callbacks for.
Modules can register callbacks using one of the module API's `register_[...]_callbacks`
methods. The callback functions are passed to these methods as keyword arguments, with
-the callback name as the argument name and the function as its value. This is demonstrated
-in the example below. A `register_[...]_callbacks` method exists for each category.
+the callback name as the argument name and the function as its value. A
+`register_[...]_callbacks` method exists for each category.
Callbacks for each category can be found on their respective page of the
[Synapse documentation website](https://matrix-org.github.io/synapse).
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 0ce8beb004..ad99b3bd2c 100755
--- a/setup.py
+++ b/setup.py
@@ -119,7 +119,9 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [
# Tests assume that all optional dependencies are installed.
#
# parameterized_class decorator was introduced in parameterized 0.7.0
-CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0"]
+#
+# We use `mock` library as that backports `AsyncMock` to Python 3.6
+CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0", "mock>=4.0.0"]
CONDITIONAL_REQUIREMENTS["dev"] = (
CONDITIONAL_REQUIREMENTS["lint"]
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 19e570ede2..a8154168be 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -82,10 +82,19 @@ from synapse.http.server import (
)
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import (
+ defer_to_thread,
+ make_deferred_yieldable,
+ run_in_background,
+)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client.login import LoginResponse
from synapse.storage import DataStore
+from synapse.storage.background_updates import (
+ DEFAULT_BATCH_SIZE_CALLBACK,
+ MIN_BATCH_SIZE_CALLBACK,
+ ON_UPDATE_CALLBACK,
+)
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.storage.state import StateFilter
@@ -311,6 +320,24 @@ class ModuleApi:
auth_checkers=auth_checkers,
)
+ def register_background_update_controller_callbacks(
+ self,
+ on_update: ON_UPDATE_CALLBACK,
+ default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+ min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None,
+ ) -> None:
+ """Registers background update controller callbacks.
+
+ Added in Synapse v1.49.0.
+ """
+
+ for db in self._hs.get_datastores().databases:
+ db.updates.register_update_controller_callbacks(
+ on_update=on_update,
+ default_batch_size=default_batch_size,
+ min_batch_size=min_batch_size,
+ )
+
def register_web_resource(self, path: str, resource: Resource) -> None:
"""Registers a web resource to be served at the given path.
@@ -995,6 +1022,11 @@ class ModuleApi:
f,
)
+ async def sleep(self, seconds: float) -> None:
+ """Sleeps for the given number of seconds."""
+
+ await self._clock.sleep(seconds)
+
async def send_mail(
self,
recipient: str,
@@ -1149,6 +1181,26 @@ class ModuleApi:
return {key: state_events[event_id] for key, event_id in state_ids.items()}
+ async def defer_to_thread(
+ self,
+ f: Callable[..., T],
+ *args: Any,
+ **kwargs: Any,
+ ) -> T:
+ """Runs the given function in a separate thread from Synapse's thread pool.
+
+ Added in Synapse v1.49.0.
+
+ Args:
+ f: The function to run.
+ args: The function's arguments.
+ kwargs: The function's keyword arguments.
+
+ Returns:
+ The return value of the function once ran in a thread.
+ """
+ return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs)
+
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index bc8364400d..d64910aded 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,12 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
+from typing import (
+ TYPE_CHECKING,
+ AsyncContextManager,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ Optional,
+)
+
+import attr
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.types import Connection
from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import Clock, json_encoder
from . import engines
@@ -28,6 +38,45 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
+DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _BackgroundUpdateHandler:
+ """A handler for a given background update.
+
+ Attributes:
+ callback: The function to call to make progress on the background
+ update.
+ oneshot: Wether the update is likely to happen all in one go, ignoring
+ the supplied target duration, e.g. index creation. This is used by
+ the update controller to help correctly schedule the update.
+ """
+
+ callback: Callable[[JsonDict, int], Awaitable[int]]
+ oneshot: bool = False
+
+
+class _BackgroundUpdateContextManager:
+ BACKGROUND_UPDATE_INTERVAL_MS = 1000
+ BACKGROUND_UPDATE_DURATION_MS = 100
+
+ def __init__(self, sleep: bool, clock: Clock):
+ self._sleep = sleep
+ self._clock = clock
+
+ async def __aenter__(self) -> int:
+ if self._sleep:
+ await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
+
+ return self.BACKGROUND_UPDATE_DURATION_MS
+
+ async def __aexit__(self, *exc) -> None:
+ pass
+
+
class BackgroundUpdatePerformance:
"""Tracks the how long a background update is taking to update its items"""
@@ -84,20 +133,22 @@ class BackgroundUpdater:
MINIMUM_BACKGROUND_BATCH_SIZE = 1
DEFAULT_BACKGROUND_BATCH_SIZE = 100
- BACKGROUND_UPDATE_INTERVAL_MS = 1000
- BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock()
self.db_pool = database
+ self._database_name = database.name()
+
# if a background update is currently running, its name.
self._current_background_update: Optional[str] = None
+ self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
+ self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
+ self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
+
self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
- self._background_update_handlers: Dict[
- str, Callable[[JsonDict, int], Awaitable[int]]
- ] = {}
+ self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
self._all_done = False
# Whether we're currently running updates
@@ -107,6 +158,83 @@ class BackgroundUpdater:
# enable/disable background updates via the admin API.
self.enabled = True
+ def register_update_controller_callbacks(
+ self,
+ on_update: ON_UPDATE_CALLBACK,
+ default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+ min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+ ) -> None:
+ """Register callbacks from a module for each hook."""
+ if self._on_update_callback is not None:
+ logger.warning(
+ "More than one module tried to register callbacks for controlling"
+ " background updates. Only the callbacks registered by the first module"
+ " (in order of appearance in Synapse's configuration file) that tried to"
+ " do so will be called."
+ )
+
+ return
+
+ self._on_update_callback = on_update
+
+ if default_batch_size is not None:
+ self._default_batch_size_callback = default_batch_size
+
+ if min_batch_size is not None:
+ self._min_batch_size_callback = min_batch_size
+
+ def _get_context_manager_for_update(
+ self,
+ sleep: bool,
+ update_name: str,
+ database_name: str,
+ oneshot: bool,
+ ) -> AsyncContextManager[int]:
+ """Get a context manager to run a background update with.
+
+ If a module has registered a `update_handler` callback, use the context manager
+ it returns.
+
+ Otherwise, returns a context manager that will return a default value, optionally
+ sleeping if needed.
+
+ Args:
+ sleep: Whether we can sleep between updates.
+ update_name: The name of the update.
+ database_name: The name of the database the update is being run on.
+ oneshot: Whether the update will complete all in one go, e.g. index creation.
+ In such cases the returned target duration is ignored.
+
+ Returns:
+ The target duration in milliseconds that the background update should run for.
+
+ Note: this is a *target*, and an iteration may take substantially longer or
+ shorter.
+ """
+ if self._on_update_callback is not None:
+ return self._on_update_callback(update_name, database_name, oneshot)
+
+ return _BackgroundUpdateContextManager(sleep, self._clock)
+
+ async def _default_batch_size(self, update_name: str, database_name: str) -> int:
+ """The batch size to use for the first iteration of a new background
+ update.
+ """
+ if self._default_batch_size_callback is not None:
+ return await self._default_batch_size_callback(update_name, database_name)
+
+ return self.DEFAULT_BACKGROUND_BATCH_SIZE
+
+ async def _min_batch_size(self, update_name: str, database_name: str) -> int:
+ """A lower bound on the batch size of a new background update.
+
+ Used to ensure that progress is always made. Must be greater than 0.
+ """
+ if self._min_batch_size_callback is not None:
+ return await self._min_batch_size_callback(update_name, database_name)
+
+ return self.MINIMUM_BACKGROUND_BATCH_SIZE
+
def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
"""Returns the current background update, if any."""
@@ -135,13 +263,8 @@ class BackgroundUpdater:
try:
logger.info("Starting background schema updates")
while self.enabled:
- if sleep:
- await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
-
try:
- result = await self.do_next_background_update(
- self.BACKGROUND_UPDATE_DURATION_MS
- )
+ result = await self.do_next_background_update(sleep)
except Exception:
logger.exception("Error doing update")
else:
@@ -203,13 +326,15 @@ class BackgroundUpdater:
return not update_exists
- async def do_next_background_update(self, desired_duration_ms: float) -> bool:
+ async def do_next_background_update(self, sleep: bool = True) -> bool:
"""Does some amount of work on the next queued background update
Returns once some amount of work is done.
Args:
- desired_duration_ms: How long we want to spend updating.
+ sleep: Whether to limit how quickly we run background updates or
+ not.
+
Returns:
True if we have finished running all the background updates, otherwise False
"""
@@ -252,7 +377,19 @@ class BackgroundUpdater:
self._current_background_update = upd["update_name"]
- await self._do_background_update(desired_duration_ms)
+ # We have a background update to run, otherwise we would have returned
+ # early.
+ assert self._current_background_update is not None
+ update_info = self._background_update_handlers[self._current_background_update]
+
+ async with self._get_context_manager_for_update(
+ sleep=sleep,
+ update_name=self._current_background_update,
+ database_name=self._database_name,
+ oneshot=update_info.oneshot,
+ ) as desired_duration_ms:
+ await self._do_background_update(desired_duration_ms)
+
return False
async def _do_background_update(self, desired_duration_ms: float) -> int:
@@ -260,7 +397,7 @@ class BackgroundUpdater:
update_name = self._current_background_update
logger.info("Starting update batch on background update '%s'", update_name)
- update_handler = self._background_update_handlers[update_name]
+ update_handler = self._background_update_handlers[update_name].callback
performance = self._background_update_performance.get(update_name)
@@ -273,9 +410,14 @@ class BackgroundUpdater:
if items_per_ms is not None:
batch_size = int(desired_duration_ms * items_per_ms)
# Clamp the batch size so that we always make progress
- batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
+ batch_size = max(
+ batch_size,
+ await self._min_batch_size(update_name, self._database_name),
+ )
else:
- batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
+ batch_size = await self._default_batch_size(
+ update_name, self._database_name
+ )
progress_json = await self.db_pool.simple_select_one_onecol(
"background_updates",
@@ -294,6 +436,8 @@ class BackgroundUpdater:
duration_ms = time_stop - time_start
+ performance.update(items_updated, duration_ms)
+
logger.info(
"Running background update %r. Processed %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
@@ -306,8 +450,6 @@ class BackgroundUpdater:
batch_size,
)
- performance.update(items_updated, duration_ms)
-
return len(self._background_update_performance)
def register_background_update_handler(
@@ -331,7 +473,9 @@ class BackgroundUpdater:
update_name: The name of the update that this code handles.
update_handler: The function that does the update.
"""
- self._background_update_handlers[update_name] = update_handler
+ self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+ update_handler
+ )
def register_noop_background_update(self, update_name: str) -> None:
"""Register a noop handler for a background update.
@@ -453,7 +597,9 @@ class BackgroundUpdater:
await self._end_background_update(update_name)
return 1
- self.register_background_update_handler(update_name, updater)
+ self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+ updater, oneshot=True
+ )
async def _end_background_update(self, update_name: str) -> None:
"""Removes a completed background update task from the queue.
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 90f800e564..f8cba7b645 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -128,6 +128,7 @@ class EmailPusherTests(HomeserverTestCase):
)
self.auth_handler = hs.get_auth_handler()
+ self.store = hs.get_datastore()
def test_need_validated_email(self):
"""Test that we can only add an email pusher if the user has validated
@@ -408,13 +409,7 @@ class EmailPusherTests(HomeserverTestCase):
self.hs.get_datastore().db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
- while not self.get_success(
- self.hs.get_datastore().db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.hs.get_datastore().db_pool.updates.do_next_background_update(100),
- by=0.1,
- )
+ self.wait_for_background_updates()
# Check that all pushers with unlinked addresses were deleted
pushers = self.get_success(
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index cd5c60b65c..62f242baf6 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -135,7 +135,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self._register_bg_update()
self.store.db_pool.updates.start_doing_background_updates()
- self.reactor.pump([1.0, 1.0])
+ self.reactor.pump([1.0, 1.0, 1.0])
channel = self.make_request(
"GET",
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index a5f5ebad41..216d816d56 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,8 +1,11 @@
-from unittest.mock import Mock
+from mock import Mock
+
+from twisted.internet.defer import Deferred, ensureDeferred
from synapse.storage.background_updates import BackgroundUpdater
from tests import unittest
+from tests.test_utils import make_awaitable
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
@@ -20,10 +23,10 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def test_do_background_update(self):
# the time we claim it takes to update one item when running the update
- duration_ms = 4200
+ duration_ms = 10
# the target runtime for each bg update
- target_background_update_duration_ms = 5000000
+ target_background_update_duration_ms = 100
store = self.hs.get_datastore()
self.get_success(
@@ -48,10 +51,8 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = update
self.update_handler.reset_mock()
res = self.get_success(
- self.updates.do_next_background_update(
- target_background_update_duration_ms
- ),
- by=0.1,
+ self.updates.do_next_background_update(False),
+ by=0.01,
)
self.assertFalse(res)
@@ -74,16 +75,93 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = update
self.update_handler.reset_mock()
- result = self.get_success(
- self.updates.do_next_background_update(target_background_update_duration_ms)
- )
+ result = self.get_success(self.updates.do_next_background_update(False))
self.assertFalse(result)
self.update_handler.assert_called_once()
# third step: we don't expect to be called any more
self.update_handler.reset_mock()
- result = self.get_success(
- self.updates.do_next_background_update(target_background_update_duration_ms)
- )
+ result = self.get_success(self.updates.do_next_background_update(False))
self.assertTrue(result)
self.assertFalse(self.update_handler.called)
+
+
+class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
+ self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates
+ # the base test class should have run the real bg updates for us
+ self.assertTrue(
+ self.get_success(self.updates.has_completed_background_updates())
+ )
+
+ self.update_deferred = Deferred()
+ self.update_handler = Mock(return_value=self.update_deferred)
+ self.updates.register_background_update_handler(
+ "test_update", self.update_handler
+ )
+
+ # Mock out the AsyncContextManager
+ self._update_ctx_manager = Mock(spec=["__aenter__", "__aexit__"])
+ self._update_ctx_manager.__aenter__ = Mock(
+ return_value=make_awaitable(None),
+ )
+ self._update_ctx_manager.__aexit__ = Mock(return_value=make_awaitable(None))
+
+ # Mock out the `update_handler` callback
+ self._on_update = Mock(return_value=self._update_ctx_manager)
+
+ # Define a default batch size value that's not the same as the internal default
+ # value (100).
+ self._default_batch_size = 500
+
+ # Register the callbacks with more mocks
+ self.hs.get_module_api().register_background_update_controller_callbacks(
+ on_update=self._on_update,
+ min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)),
+ default_batch_size=Mock(
+ return_value=make_awaitable(self._default_batch_size),
+ ),
+ )
+
+ def test_controller(self):
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "background_updates",
+ values={"update_name": "test_update", "progress_json": "{}"},
+ )
+ )
+
+ # Set the return value for the context manager.
+ enter_defer = Deferred()
+ self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
+
+ # Start the background update.
+ do_update_d = ensureDeferred(self.updates.do_next_background_update(True))
+
+ self.pump()
+
+ # `run_update` should have been called, but the update handler won't be
+ # called until the `enter_defer` (returned by `__aenter__`) is resolved.
+ self._on_update.assert_called_once_with(
+ "test_update",
+ "master",
+ False,
+ )
+ self.assertFalse(do_update_d.called)
+ self.assertFalse(self.update_deferred.called)
+
+ # Resolving the `enter_defer` should call the update handler, which then
+ # blocks.
+ enter_defer.callback(100)
+ self.pump()
+ self.update_handler.assert_called_once_with({}, self._default_batch_size)
+ self.assertFalse(self.update_deferred.called)
+ self._update_ctx_manager.__aexit__.assert_not_called()
+
+ # Resolving the update handler deferred should cause the
+ # `do_next_background_update` to finish and return
+ self.update_deferred.callback(100)
+ self.pump()
+ self._update_ctx_manager.__aexit__.assert_called()
+ self.get_success(do_update_d)
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index b31c5eb5ec..7b7f6c349e 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -664,7 +664,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
):
iterations += 1
self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(False), by=0.1
)
# Ensure that we did actually take multiple iterations to process the
@@ -723,7 +723,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
):
iterations += 1
self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(False), by=0.1
)
# Ensure that we did actually take multiple iterations to process the
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 37cf7bb232..7f5b28aed8 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -23,6 +23,7 @@ from synapse.rest import admin
from synapse.rest.client import login, register, room
from synapse.server import HomeServer
from synapse.storage import DataStore
+from synapse.storage.background_updates import _BackgroundUpdateHandler
from synapse.storage.roommember import ProfileInfo
from synapse.util import Clock
@@ -391,7 +392,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
with mock.patch.dict(
self.store.db_pool.updates._background_update_handlers,
- populate_user_directory_process_users=mocked_process_users,
+ populate_user_directory_process_users=_BackgroundUpdateHandler(
+ mocked_process_users,
+ ),
):
self._purge_and_rebuild_user_dir()
diff --git a/tests/unittest.py b/tests/unittest.py
index 165aafc574..eea0903f05 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -331,17 +331,16 @@ class HomeserverTestCase(TestCase):
time.sleep(0.01)
def wait_for_background_updates(self) -> None:
- """
- Block until all background database updates have completed.
+ """Block until all background database updates have completed.
- Note that callers must ensure that's a store property created on the
+ Note that callers must ensure there's a store property created on the
testcase.
"""
while not self.get_success(
self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(False), by=0.1
)
def make_homeserver(self, reactor, clock):
@@ -500,8 +499,7 @@ class HomeserverTestCase(TestCase):
async def run_bg_updates():
with LoggingContext("run_bg_updates"):
- while not await stor.db_pool.updates.has_completed_background_updates():
- await stor.db_pool.updates.do_next_background_update(1)
+ self.get_success(stor.db_pool.updates.run_background_updates(False))
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
|