From 75574726a766f09d955c05672d400c65cb341810 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 4 Mar 2022 15:37:02 +0000 Subject: Add type hints for `ObservableDeferred` attributes (#12159) Signed-off-by: Sean Quah --- synapse/util/async_helpers.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) (limited to 'synapse/util/async_helpers.py') diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 60c03a66fd..a9f67dcbac 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -40,7 +40,7 @@ from typing import ( ) import attr -from typing_extensions import ContextManager +from typing_extensions import ContextManager, Literal from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -96,6 +96,10 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): __slots__ = ["_deferred", "_observers", "_result"] + _deferred: "defer.Deferred[_T]" + _observers: Union[List["defer.Deferred[_T]"], Tuple[()]] + _result: Union[None, Tuple[Literal[True], _T], Tuple[Literal[False], Failure]] + def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) @@ -158,12 +162,14 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): effect the underlying deferred. """ if not self._result: + assert isinstance(self._observers, list) d: "defer.Deferred[_T]" = defer.Deferred() self._observers.append(d) return d + elif self._result[0]: + return defer.succeed(self._result[1]) else: - success, res = self._result - return defer.succeed(res) if success else defer.fail(res) + return defer.fail(self._result[1]) def observers(self) -> "Collection[defer.Deferred[_T]]": return self._observers @@ -175,6 +181,8 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): return self._result is not None and self._result[0] is True def get_result(self) -> Union[_T, Failure]: + if self._result is None: + raise ValueError(f"{self!r} has no result yet") return self._result[1] def __getattr__(self, name: str) -> Any: -- cgit 1.5.1 From 90b2327066d2343faa86c464a182b6f3c4422ecd Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 14 Mar 2022 17:52:15 +0000 Subject: Add `delay_cancellation` utility function (#12180) `delay_cancellation` behaves like `stop_cancellation`, except it delays `CancelledError`s until the original `Deferred` resolves. This is handy for unifying cleanup paths and ensuring that uncancelled coroutines don't use finished logcontexts. Signed-off-by: Sean Quah --- changelog.d/12180.misc | 1 + synapse/util/async_helpers.py | 48 +++++++++++++-- tests/util/test_async_helpers.py | 124 +++++++++++++++++++++++++++++++++++++-- 3 files changed, 161 insertions(+), 12 deletions(-) create mode 100644 changelog.d/12180.misc (limited to 'synapse/util/async_helpers.py') diff --git a/changelog.d/12180.misc b/changelog.d/12180.misc new file mode 100644 index 0000000000..7a347352fd --- /dev/null +++ b/changelog.d/12180.misc @@ -0,0 +1 @@ +Add `delay_cancellation` utility function, which behaves like `stop_cancellation` but waits until the original `Deferred` resolves before raising a `CancelledError`. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index a9f67dcbac..69c8c1baa9 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -686,12 +686,48 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": Synapse logcontext rules. Returns: - A new `Deferred`, which will contain the result of the original `Deferred`, - but will not propagate cancellation through to the original. When cancelled, - the new `Deferred` will fail with a `CancelledError` and will not follow the - Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap - the new `Deferred`. + A new `Deferred`, which will contain the result of the original `Deferred`. + The new `Deferred` will not propagate cancellation through to the original. + When cancelled, the new `Deferred` will fail with a `CancelledError`. + + The new `Deferred` will not follow the Synapse logcontext rules and should be + wrapped with `make_deferred_yieldable`. + """ + new_deferred: "defer.Deferred[T]" = defer.Deferred() + deferred.chainDeferred(new_deferred) + return new_deferred + + +def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": + """Delay cancellation of a `Deferred` until it resolves. + + Has the same effect as `stop_cancellation`, but the returned `Deferred` will not + resolve with a `CancelledError` until the original `Deferred` resolves. + + Args: + deferred: The `Deferred` to protect against cancellation. May optionally follow + the Synapse logcontext rules. + + Returns: + A new `Deferred`, which will contain the result of the original `Deferred`. + The new `Deferred` will not propagate cancellation through to the original. + When cancelled, the new `Deferred` will wait until the original `Deferred` + resolves before failing with a `CancelledError`. + + The new `Deferred` will follow the Synapse logcontext rules if `deferred` + follows the Synapse logcontext rules. Otherwise the new `Deferred` should be + wrapped with `make_deferred_yieldable`. """ - new_deferred: defer.Deferred[T] = defer.Deferred() + + def handle_cancel(new_deferred: "defer.Deferred[T]") -> None: + # before the new deferred is cancelled, we `pause` it to stop the cancellation + # propagating. we then `unpause` it once the wrapped deferred completes, to + # propagate the exception. + new_deferred.pause() + new_deferred.errback(Failure(CancelledError())) + + deferred.addBoth(lambda _: new_deferred.unpause()) + + new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel) deferred.chainDeferred(new_deferred) return new_deferred diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index ff53ce114b..e5bc416de1 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -13,6 +13,8 @@ # limitations under the License. import traceback +from parameterized import parameterized_class + from twisted.internet import defer from twisted.internet.defer import CancelledError, Deferred, ensureDeferred from twisted.internet.task import Clock @@ -23,10 +25,12 @@ from synapse.logging.context import ( LoggingContext, PreserveLoggingContext, current_context, + make_deferred_yieldable, ) from synapse.util.async_helpers import ( ObservableDeferred, concurrently_execute, + delay_cancellation, stop_cancellation, timeout_deferred, ) @@ -313,13 +317,27 @@ class ConcurrentlyExecuteTest(TestCase): self.successResultOf(d2) -class StopCancellationTests(TestCase): - """Tests for the `stop_cancellation` function.""" +@parameterized_class( + ("wrapper",), + [("stop_cancellation",), ("delay_cancellation",)], +) +class CancellationWrapperTests(TestCase): + """Common tests for the `stop_cancellation` and `delay_cancellation` functions.""" + + wrapper: str + + def wrap_deferred(self, deferred: "Deferred[str]") -> "Deferred[str]": + if self.wrapper == "stop_cancellation": + return stop_cancellation(deferred) + elif self.wrapper == "delay_cancellation": + return delay_cancellation(deferred) + else: + raise ValueError(f"Unsupported wrapper type: {self.wrapper}") def test_succeed(self): """Test that the new `Deferred` receives the result.""" deferred: "Deferred[str]" = Deferred() - wrapper_deferred = stop_cancellation(deferred) + wrapper_deferred = self.wrap_deferred(deferred) # Success should propagate through. deferred.callback("success") @@ -329,7 +347,7 @@ class StopCancellationTests(TestCase): def test_failure(self): """Test that the new `Deferred` receives the `Failure`.""" deferred: "Deferred[str]" = Deferred() - wrapper_deferred = stop_cancellation(deferred) + wrapper_deferred = self.wrap_deferred(deferred) # Failure should propagate through. deferred.errback(ValueError("abc")) @@ -337,6 +355,10 @@ class StopCancellationTests(TestCase): self.failureResultOf(wrapper_deferred, ValueError) self.assertIsNone(deferred.result, "`Failure` was not consumed") + +class StopCancellationTests(TestCase): + """Tests for the `stop_cancellation` function.""" + def test_cancellation(self): """Test that cancellation of the new `Deferred` leaves the original running.""" deferred: "Deferred[str]" = Deferred() @@ -347,11 +369,101 @@ class StopCancellationTests(TestCase): self.assertTrue(wrapper_deferred.called) self.failureResultOf(wrapper_deferred, CancelledError) self.assertFalse( - deferred.called, "Original `Deferred` was unexpectedly cancelled." + deferred.called, "Original `Deferred` was unexpectedly cancelled" + ) + + # Now make the original `Deferred` fail. + # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed + # in logs. + deferred.errback(ValueError("abc")) + self.assertIsNone(deferred.result, "`Failure` was not consumed") + + +class DelayCancellationTests(TestCase): + """Tests for the `delay_cancellation` function.""" + + def test_cancellation(self): + """Test that cancellation of the new `Deferred` waits for the original.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = delay_cancellation(deferred) + + # Cancel the new `Deferred`. + wrapper_deferred.cancel() + self.assertNoResult(wrapper_deferred) + self.assertFalse( + deferred.called, "Original `Deferred` was unexpectedly cancelled" + ) + + # Now make the original `Deferred` fail. + # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed + # in logs. + deferred.errback(ValueError("abc")) + self.assertIsNone(deferred.result, "`Failure` was not consumed") + + # Now that the original `Deferred` has failed, we should get a `CancelledError`. + self.failureResultOf(wrapper_deferred, CancelledError) + + def test_suppresses_second_cancellation(self): + """Test that a second cancellation is suppressed. + + Identical to `test_cancellation` except the new `Deferred` is cancelled twice. + """ + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = delay_cancellation(deferred) + + # Cancel the new `Deferred`, twice. + wrapper_deferred.cancel() + wrapper_deferred.cancel() + self.assertNoResult(wrapper_deferred) + self.assertFalse( + deferred.called, "Original `Deferred` was unexpectedly cancelled" ) - # Now make the inner `Deferred` fail. + # Now make the original `Deferred` fail. # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed # in logs. deferred.errback(ValueError("abc")) self.assertIsNone(deferred.result, "`Failure` was not consumed") + + # Now that the original `Deferred` has failed, we should get a `CancelledError`. + self.failureResultOf(wrapper_deferred, CancelledError) + + def test_propagates_cancelled_error(self): + """Test that a `CancelledError` from the original `Deferred` gets propagated.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = delay_cancellation(deferred) + + # Fail the original `Deferred` with a `CancelledError`. + cancelled_error = CancelledError() + deferred.errback(cancelled_error) + + # The new `Deferred` should fail with exactly the same `CancelledError`. + self.assertTrue(wrapper_deferred.called) + self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value) + + def test_preserves_logcontext(self): + """Test that logging contexts are preserved.""" + blocking_d: "Deferred[None]" = Deferred() + + async def inner(): + await make_deferred_yieldable(blocking_d) + + async def outer(): + with LoggingContext("c") as c: + try: + await delay_cancellation(defer.ensureDeferred(inner())) + self.fail("`CancelledError` was not raised") + except CancelledError: + self.assertEqual(c, current_context()) + # Succeed with no error, unless the logging context is wrong. + + # Run and block inside `inner()`. + d = defer.ensureDeferred(outer()) + self.assertEqual(SENTINEL_CONTEXT, current_context()) + + d.cancel() + + # Now unblock. `outer()` will consume the `CancelledError` and check the + # logging context. + blocking_d.callback(None) + self.successResultOf(d) -- cgit 1.5.1 From 605d161d7d585847fd1bb98d14d5281daeac8e86 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 14 Mar 2022 18:49:07 +0000 Subject: Add cancellation support to `ReadWriteLock` (#12120) Also convert `ReadWriteLock` to use async context managers. Signed-off-by: Sean Quah --- changelog.d/12120.misc | 1 + synapse/handlers/pagination.py | 8 +- synapse/util/async_helpers.py | 71 ++++---- tests/util/test_rwlock.py | 395 +++++++++++++++++++++++++++++++++++------ 4 files changed, 382 insertions(+), 93 deletions(-) create mode 100644 changelog.d/12120.misc (limited to 'synapse/util/async_helpers.py') diff --git a/changelog.d/12120.misc b/changelog.d/12120.misc new file mode 100644 index 0000000000..3603096500 --- /dev/null +++ b/changelog.d/12120.misc @@ -0,0 +1 @@ +Add support for cancellation to `ReadWriteLock`. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 183fabcfc0..60059fec3e 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -350,7 +350,7 @@ class PaginationHandler: """ self._purges_in_progress_by_room.add(room_id) try: - with await self.pagination_lock.write(room_id): + async with self.pagination_lock.write(room_id): await self.storage.purge_events.purge_history( room_id, token, delete_local_events ) @@ -406,7 +406,7 @@ class PaginationHandler: room_id: room to be purged force: set true to skip checking for joined users. """ - with await self.pagination_lock.write(room_id): + async with self.pagination_lock.write(room_id): # first check that we have no users in this room if not force: joined = await self.store.is_host_joined(room_id, self._server_name) @@ -448,7 +448,7 @@ class PaginationHandler: room_token = from_token.room_key - with await self.pagination_lock.read(room_id): + async with self.pagination_lock.read(room_id): ( membership, member_event_id, @@ -615,7 +615,7 @@ class PaginationHandler: self._purges_in_progress_by_room.add(room_id) try: - with await self.pagination_lock.write(room_id): + async with self.pagination_lock.write(room_id): self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN self._delete_by_id[ delete_id diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 69c8c1baa9..6a8e844d63 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -18,9 +18,10 @@ import collections import inspect import itertools import logging -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from typing import ( Any, + AsyncIterator, Awaitable, Callable, Collection, @@ -40,7 +41,7 @@ from typing import ( ) import attr -from typing_extensions import ContextManager, Literal +from typing_extensions import AsyncContextManager, Literal from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -491,7 +492,7 @@ class ReadWriteLock: Example: - with await read_write_lock.read("test_key"): + async with read_write_lock.read("test_key"): # do some work """ @@ -514,22 +515,24 @@ class ReadWriteLock: # Latest writer queued self.key_to_current_writer: Dict[str, defer.Deferred] = {} - async def read(self, key: str) -> ContextManager: - new_defer: "defer.Deferred[None]" = defer.Deferred() + def read(self, key: str) -> AsyncContextManager: + @asynccontextmanager + async def _ctx_manager() -> AsyncIterator[None]: + new_defer: "defer.Deferred[None]" = defer.Deferred() - curr_readers = self.key_to_current_readers.setdefault(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) + curr_readers = self.key_to_current_readers.setdefault(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) - curr_readers.add(new_defer) + curr_readers.add(new_defer) - # We wait for the latest writer to finish writing. We can safely ignore - # any existing readers... as they're readers. - if curr_writer: - await make_deferred_yieldable(curr_writer) - - @contextmanager - def _ctx_manager() -> Iterator[None]: try: + # We wait for the latest writer to finish writing. We can safely ignore + # any existing readers... as they're readers. + # May raise a `CancelledError` if the `Deferred` wrapping us is + # cancelled. The `Deferred` we are waiting on must not be cancelled, + # since we do not own it. + if curr_writer: + await make_deferred_yieldable(stop_cancellation(curr_writer)) yield finally: with PreserveLoggingContext(): @@ -538,29 +541,35 @@ class ReadWriteLock: return _ctx_manager() - async def write(self, key: str) -> ContextManager: - new_defer: "defer.Deferred[None]" = defer.Deferred() + def write(self, key: str) -> AsyncContextManager: + @asynccontextmanager + async def _ctx_manager() -> AsyncIterator[None]: + new_defer: "defer.Deferred[None]" = defer.Deferred() - curr_readers = self.key_to_current_readers.get(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) + curr_readers = self.key_to_current_readers.get(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) - # We wait on all latest readers and writer. - to_wait_on = list(curr_readers) - if curr_writer: - to_wait_on.append(curr_writer) + # We wait on all latest readers and writer. + to_wait_on = list(curr_readers) + if curr_writer: + to_wait_on.append(curr_writer) - # We can clear the list of current readers since the new writer waits - # for them to finish. - curr_readers.clear() - self.key_to_current_writer[key] = new_defer + # We can clear the list of current readers since `new_defer` waits + # for them to finish. + curr_readers.clear() + self.key_to_current_writer[key] = new_defer - await make_deferred_yieldable(defer.gatherResults(to_wait_on)) - - @contextmanager - def _ctx_manager() -> Iterator[None]: + to_wait_on_defer = defer.gatherResults(to_wait_on) try: + # Wait for all current readers and the latest writer to finish. + # May raise a `CancelledError` immediately after the wait if the + # `Deferred` wrapping us is cancelled. We must only release the lock + # once we have acquired it, hence the use of `delay_cancellation` + # rather than `stop_cancellation`. + await make_deferred_yieldable(delay_cancellation(to_wait_on_defer)) yield finally: + # Release the lock. with PreserveLoggingContext(): new_defer.callback(None) # `self.key_to_current_writer[key]` may be missing if there was another diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index 0774625b85..0c84226197 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import AsyncContextManager, Callable, Sequence, Tuple + from twisted.internet import defer -from twisted.internet.defer import Deferred +from twisted.internet.defer import CancelledError, Deferred from synapse.util.async_helpers import ReadWriteLock @@ -21,87 +23,187 @@ from tests import unittest class ReadWriteLockTestCase(unittest.TestCase): - def _assert_called_before_not_after(self, lst, first_false): - for i, d in enumerate(lst[:first_false]): - self.assertTrue(d.called, msg="%d was unexpectedly false" % i) + def _start_reader_or_writer( + self, + read_or_write: Callable[[str], AsyncContextManager], + key: str, + return_value: str, + ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]: + """Starts a reader or writer which acquires the lock, blocks, then completes. + + Args: + read_or_write: A function returning a context manager for a lock. + Either a bound `ReadWriteLock.read` or `ReadWriteLock.write`. + key: The key to read or write. + return_value: A string that the reader or writer will resolve with when + done. + + Returns: + A tuple of three `Deferred`s: + * A `Deferred` that resolves with `return_value` once the reader or writer + completes successfully. + * A `Deferred` that resolves once the reader or writer acquires the lock. + * A `Deferred` that blocks the reader or writer. Must be resolved by the + caller to allow the reader or writer to release the lock and complete. + """ + acquired_d: "Deferred[None]" = Deferred() + unblock_d: "Deferred[None]" = Deferred() + + async def reader_or_writer(): + async with read_or_write(key): + acquired_d.callback(None) + await unblock_d + return return_value + + d = defer.ensureDeferred(reader_or_writer()) + return d, acquired_d, unblock_d + + def _start_blocking_reader( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]: + """Starts a reader which acquires the lock, blocks, then releases the lock. + + See the docstring for `_start_reader_or_writer` for details about the arguments + and return values. + """ + return self._start_reader_or_writer(rwlock.read, key, return_value) + + def _start_blocking_writer( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]: + """Starts a writer which acquires the lock, blocks, then releases the lock. + + See the docstring for `_start_reader_or_writer` for details about the arguments + and return values. + """ + return self._start_reader_or_writer(rwlock.write, key, return_value) + + def _start_nonblocking_reader( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]"]: + """Starts a reader which acquires the lock, then releases it immediately. + + See the docstring for `_start_reader_or_writer` for details about the arguments. + + Returns: + A tuple of two `Deferred`s: + * A `Deferred` that resolves with `return_value` once the reader completes + successfully. + * A `Deferred` that resolves once the reader acquires the lock. + """ + d, acquired_d, unblock_d = self._start_reader_or_writer( + rwlock.read, key, return_value + ) + unblock_d.callback(None) + return d, acquired_d + + def _start_nonblocking_writer( + self, rwlock: ReadWriteLock, key: str, return_value: str + ) -> Tuple["Deferred[str]", "Deferred[None]"]: + """Starts a writer which acquires the lock, then releases it immediately. + + See the docstring for `_start_reader_or_writer` for details about the arguments. + + Returns: + A tuple of two `Deferred`s: + * A `Deferred` that resolves with `return_value` once the writer completes + successfully. + * A `Deferred` that resolves once the writer acquires the lock. + """ + d, acquired_d, unblock_d = self._start_reader_or_writer( + rwlock.write, key, return_value + ) + unblock_d.callback(None) + return d, acquired_d + + def _assert_first_n_resolved( + self, deferreds: Sequence["defer.Deferred[None]"], n: int + ) -> None: + """Assert that exactly the first n `Deferred`s in the given list are resolved. - for i, d in enumerate(lst[first_false:]): + Args: + deferreds: The list of `Deferred`s to be checked. + n: The number of `Deferred`s at the start of `deferreds` that should be + resolved. + """ + for i, d in enumerate(deferreds[:n]): + self.assertTrue(d.called, msg="deferred %d was unexpectedly unresolved" % i) + + for i, d in enumerate(deferreds[n:]): self.assertFalse( - d.called, msg="%d was unexpectedly true" % (i + first_false) + d.called, msg="deferred %d was unexpectedly resolved" % (i + n) ) def test_rwlock(self): rwlock = ReadWriteLock() - - key = object() + key = "key" ds = [ - rwlock.read(key), # 0 - rwlock.read(key), # 1 - rwlock.write(key), # 2 - rwlock.write(key), # 3 - rwlock.read(key), # 4 - rwlock.read(key), # 5 - rwlock.write(key), # 6 + self._start_blocking_reader(rwlock, key, "0"), + self._start_blocking_reader(rwlock, key, "1"), + self._start_blocking_writer(rwlock, key, "2"), + self._start_blocking_writer(rwlock, key, "3"), + self._start_blocking_reader(rwlock, key, "4"), + self._start_blocking_reader(rwlock, key, "5"), + self._start_blocking_writer(rwlock, key, "6"), ] - ds = [defer.ensureDeferred(d) for d in ds] + # `Deferred`s that resolve when each reader or writer acquires the lock. + acquired_ds = [acquired_d for _, acquired_d, _ in ds] + # `Deferred`s that will trigger the release of locks when resolved. + release_ds = [release_d for _, _, release_d in ds] - self._assert_called_before_not_after(ds, 2) + # The first two readers should acquire their locks. + self._assert_first_n_resolved(acquired_ds, 2) - with ds[0].result: - self._assert_called_before_not_after(ds, 2) - self._assert_called_before_not_after(ds, 2) + # Release one of the read locks. The next writer should not acquire the lock, + # because there is another reader holding the lock. + self._assert_first_n_resolved(acquired_ds, 2) + release_ds[0].callback(None) + self._assert_first_n_resolved(acquired_ds, 2) - with ds[1].result: - self._assert_called_before_not_after(ds, 2) - self._assert_called_before_not_after(ds, 3) + # Release the other read lock. The next writer should acquire the lock. + self._assert_first_n_resolved(acquired_ds, 2) + release_ds[1].callback(None) + self._assert_first_n_resolved(acquired_ds, 3) - with ds[2].result: - self._assert_called_before_not_after(ds, 3) - self._assert_called_before_not_after(ds, 4) + # Release the write lock. The next writer should acquire the lock. + self._assert_first_n_resolved(acquired_ds, 3) + release_ds[2].callback(None) + self._assert_first_n_resolved(acquired_ds, 4) - with ds[3].result: - self._assert_called_before_not_after(ds, 4) - self._assert_called_before_not_after(ds, 6) + # Release the write lock. The next two readers should acquire locks. + self._assert_first_n_resolved(acquired_ds, 4) + release_ds[3].callback(None) + self._assert_first_n_resolved(acquired_ds, 6) - with ds[5].result: - self._assert_called_before_not_after(ds, 6) - self._assert_called_before_not_after(ds, 6) + # Release one of the read locks. The next writer should not acquire the lock, + # because there is another reader holding the lock. + self._assert_first_n_resolved(acquired_ds, 6) + release_ds[5].callback(None) + self._assert_first_n_resolved(acquired_ds, 6) - with ds[4].result: - self._assert_called_before_not_after(ds, 6) - self._assert_called_before_not_after(ds, 7) + # Release the other read lock. The next writer should acquire the lock. + self._assert_first_n_resolved(acquired_ds, 6) + release_ds[4].callback(None) + self._assert_first_n_resolved(acquired_ds, 7) - with ds[6].result: - pass + # Release the write lock. + release_ds[6].callback(None) - d = defer.ensureDeferred(rwlock.write(key)) - self.assertTrue(d.called) - with d.result: - pass + # Acquire and release the write and read locks one last time for good measure. + _, acquired_d = self._start_nonblocking_writer(rwlock, key, "last writer") + self.assertTrue(acquired_d.called) - d = defer.ensureDeferred(rwlock.read(key)) - self.assertTrue(d.called) - with d.result: - pass + _, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader") + self.assertTrue(acquired_d.called) def test_lock_handoff_to_nonblocking_writer(self): """Test a writer handing the lock to another writer that completes instantly.""" rwlock = ReadWriteLock() key = "key" - unblock: "Deferred[None]" = Deferred() - - async def blocking_write(): - with await rwlock.write(key): - await unblock - - async def nonblocking_write(): - with await rwlock.write(key): - pass - - d1 = defer.ensureDeferred(blocking_write()) - d2 = defer.ensureDeferred(nonblocking_write()) + d1, _, unblock = self._start_blocking_writer(rwlock, key, "write 1 completed") + d2, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed") self.assertFalse(d1.called) self.assertFalse(d2.called) @@ -111,5 +213,182 @@ class ReadWriteLockTestCase(unittest.TestCase): self.assertTrue(d2.called) # The `ReadWriteLock` should operate as normal. - d3 = defer.ensureDeferred(nonblocking_write()) + d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed") self.assertTrue(d3.called) + + def test_cancellation_while_holding_read_lock(self): + """Test cancellation while holding a read lock. + + A waiting writer should be given the lock when the reader holding the lock is + cancelled. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A reader takes the lock and blocks. + reader_d, _, _ = self._start_blocking_reader(rwlock, key, "read completed") + + # 2. A writer waits for the reader to complete. + writer_d, _ = self._start_nonblocking_writer(rwlock, key, "write completed") + self.assertFalse(writer_d.called) + + # 3. The reader is cancelled. + reader_d.cancel() + self.failureResultOf(reader_d, CancelledError) + + # 4. The writer should take the lock and complete. + self.assertTrue( + writer_d.called, "Writer is stuck waiting for a cancelled reader" + ) + self.assertEqual("write completed", self.successResultOf(writer_d)) + + def test_cancellation_while_holding_write_lock(self): + """Test cancellation while holding a write lock. + + A waiting reader should be given the lock when the writer holding the lock is + cancelled. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A writer takes the lock and blocks. + writer_d, _, _ = self._start_blocking_writer(rwlock, key, "write completed") + + # 2. A reader waits for the writer to complete. + reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed") + self.assertFalse(reader_d.called) + + # 3. The writer is cancelled. + writer_d.cancel() + self.failureResultOf(writer_d, CancelledError) + + # 4. The reader should take the lock and complete. + self.assertTrue( + reader_d.called, "Reader is stuck waiting for a cancelled writer" + ) + self.assertEqual("read completed", self.successResultOf(reader_d)) + + def test_cancellation_while_waiting_for_read_lock(self): + """Test cancellation while waiting for a read lock. + + Tests that cancelling a waiting reader: + * does not cancel the writer it is waiting on + * does not cancel the next writer waiting on it + * does not allow the next writer to acquire the lock before an earlier writer + has finished + * does not keep the next writer waiting indefinitely + + These correspond to the asserts with explicit messages. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A writer takes the lock and blocks. + writer1_d, _, unblock_writer1 = self._start_blocking_writer( + rwlock, key, "write 1 completed" + ) + + # 2. A reader waits for the first writer to complete. + # This reader will be cancelled later. + reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed") + self.assertFalse(reader_d.called) + + # 3. A second writer waits for both the first writer and the reader to complete. + writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed") + self.assertFalse(writer2_d.called) + + # 4. The waiting reader is cancelled. + # Neither of the writers should be cancelled. + # The second writer should still be waiting, but only on the first writer. + reader_d.cancel() + self.failureResultOf(reader_d, CancelledError) + self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled") + self.assertFalse( + writer2_d.called, + "Second writer was unexpectedly cancelled or given the lock before the " + "first writer finished", + ) + + # 5. Unblock the first writer, which should complete. + unblock_writer1.callback(None) + self.assertEqual("write 1 completed", self.successResultOf(writer1_d)) + + # 6. The second writer should take the lock and complete. + self.assertTrue( + writer2_d.called, "Second writer is stuck waiting for a cancelled reader" + ) + self.assertEqual("write 2 completed", self.successResultOf(writer2_d)) + + def test_cancellation_while_waiting_for_write_lock(self): + """Test cancellation while waiting for a write lock. + + Tests that cancelling a waiting writer: + * does not cancel the reader or writer it is waiting on + * does not cancel the next writer waiting on it + * does not allow the next writer to acquire the lock before an earlier reader + and writer have finished + * does not keep the next writer waiting indefinitely + + These correspond to the asserts with explicit messages. + """ + rwlock = ReadWriteLock() + key = "key" + + # 1. A reader takes the lock and blocks. + reader_d, _, unblock_reader = self._start_blocking_reader( + rwlock, key, "read completed" + ) + + # 2. A writer waits for the reader to complete. + writer1_d, _, unblock_writer1 = self._start_blocking_writer( + rwlock, key, "write 1 completed" + ) + + # 3. A second writer waits for both the reader and first writer to complete. + # This writer will be cancelled later. + writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed") + self.assertFalse(writer2_d.called) + + # 4. A third writer waits for the second writer to complete. + writer3_d, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed") + self.assertFalse(writer3_d.called) + + # 5. The second writer is cancelled, but continues waiting for the lock. + # The reader, first writer and third writer should not be cancelled. + # The first writer should still be waiting on the reader. + # The third writer should still be waiting on the second writer. + writer2_d.cancel() + self.assertNoResult(writer2_d) + self.assertFalse(reader_d.called, "Reader was unexpectedly cancelled") + self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled") + self.assertFalse( + writer3_d.called, + "Third writer was unexpectedly cancelled or given the lock before the first " + "writer finished", + ) + + # 6. Unblock the reader, which should complete. + # The first writer should be given the lock and block. + # The third writer should still be waiting on the second writer. + unblock_reader.callback(None) + self.assertEqual("read completed", self.successResultOf(reader_d)) + self.assertNoResult(writer2_d) + self.assertFalse( + writer3_d.called, + "Third writer was unexpectedly given the lock before the first writer " + "finished", + ) + + # 7. Unblock the first writer, which should complete. + unblock_writer1.callback(None) + self.assertEqual("write 1 completed", self.successResultOf(writer1_d)) + + # 8. The second writer should take the lock and release it immediately, since it + # has been cancelled. + self.failureResultOf(writer2_d, CancelledError) + + # 9. The third writer should take the lock and complete. + self.assertTrue( + writer3_d.called, "Third writer is stuck waiting for a cancelled writer" + ) + self.assertEqual("write 3 completed", self.successResultOf(writer3_d)) -- cgit 1.5.1