diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 19741ffcda..48e616ac74 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -17,7 +17,7 @@ from typing import Set
from unittest import mock
from twisted.internet import defer, reactor
-from twisted.internet.defer import Deferred
+from twisted.internet.defer import CancelledError, Deferred
from synapse.api.errors import SynapseError
from synapse.logging.context import (
@@ -28,7 +28,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached, lru_cache
+from synapse.util.caches.descriptors import cached, cachedList, lru_cache
from tests import unittest
from tests.test_utils import get_awaitable_result
@@ -141,6 +141,84 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()
+ @defer.inlineCallbacks
+ def test_cache_uncached_args(self):
+ """
+ Only the arguments not named in uncached_args should matter to the cache
+
+ Note that this is identical to test_cache_num_args, but provides the
+ arguments differently.
+ """
+
+ class Cls:
+ # Note that it is important that this is not the last argument to
+ # test behaviour of skipping arguments properly.
+ @descriptors.cached(uncached_args=("arg2",))
+ def fn(self, arg1, arg2, arg3):
+ return self.mock(arg1, arg2, arg3)
+
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ obj = Cls()
+ obj.mock.return_value = "fish"
+ r = yield obj.fn(1, 2, 3)
+ self.assertEqual(r, "fish")
+ obj.mock.assert_called_once_with(1, 2, 3)
+ obj.mock.reset_mock()
+
+ # a call with different params should call the mock again
+ obj.mock.return_value = "chips"
+ r = yield obj.fn(2, 3, 4)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_called_once_with(2, 3, 4)
+ obj.mock.reset_mock()
+
+ # the two values should now be cached; we should be able to vary
+ # the second argument and still get the cached result.
+ r = yield obj.fn(1, 4, 3)
+ self.assertEqual(r, "fish")
+ r = yield obj.fn(2, 5, 4)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_not_called()
+
+ @defer.inlineCallbacks
+ def test_cache_kwargs(self):
+ """Test that keyword arguments are treated properly"""
+
+ class Cls:
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached()
+ def fn(self, arg1, kwarg1=2):
+ return self.mock(arg1, kwarg1=kwarg1)
+
+ obj = Cls()
+ obj.mock.return_value = "fish"
+ r = yield obj.fn(1, kwarg1=2)
+ self.assertEqual(r, "fish")
+ obj.mock.assert_called_once_with(1, kwarg1=2)
+ obj.mock.reset_mock()
+
+ # a call with different params should call the mock again
+ obj.mock.return_value = "chips"
+ r = yield obj.fn(1, kwarg1=3)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_called_once_with(1, kwarg1=3)
+ obj.mock.reset_mock()
+
+ # the values should now be cached.
+ r = yield obj.fn(1, kwarg1=2)
+ self.assertEqual(r, "fish")
+ # We should be able to not provide kwarg1 and get the cached value back.
+ r = yield obj.fn(1)
+ self.assertEqual(r, "fish")
+ # Keyword arguments can be in any order.
+ r = yield obj.fn(kwarg1=2, arg1=1)
+ self.assertEqual(r, "fish")
+ obj.mock.assert_not_called()
+
def test_cache_with_sync_exception(self):
"""If the wrapped function throws synchronously, things should continue to work"""
@@ -415,6 +493,74 @@ class DescriptorTestCase(unittest.TestCase):
obj.invalidate()
top_invalidate.assert_called_once()
+ def test_cancel(self):
+ """Test that cancelling a lookup does not cancel other lookups"""
+ complete_lookup: "Deferred[None]" = Deferred()
+
+ class Cls:
+ @cached()
+ async def fn(self, arg1):
+ await complete_lookup
+ return str(arg1)
+
+ obj = Cls()
+
+ d1 = obj.fn(123)
+ d2 = obj.fn(123)
+ self.assertFalse(d1.called)
+ self.assertFalse(d2.called)
+
+ # Cancel `d1`, which is the lookup that caused `fn` to run.
+ d1.cancel()
+
+ # `d2` should complete normally.
+ complete_lookup.callback(None)
+ self.failureResultOf(d1, CancelledError)
+ self.assertEqual(d2.result, "123")
+
+ def test_cancel_logcontexts(self):
+ """Test that cancellation does not break logcontexts.
+
+ * The `CancelledError` must be raised with the correct logcontext.
+ * The inner lookup must not resume with a finished logcontext.
+ * The inner lookup must not restore a finished logcontext when done.
+ """
+ complete_lookup: "Deferred[None]" = Deferred()
+
+ class Cls:
+ inner_context_was_finished = False
+
+ @cached()
+ async def fn(self, arg1):
+ await make_deferred_yieldable(complete_lookup)
+ self.inner_context_was_finished = current_context().finished
+ return str(arg1)
+
+ obj = Cls()
+
+ async def do_lookup():
+ with LoggingContext("c1") as c1:
+ try:
+ await obj.fn(123)
+ self.fail("No CancelledError thrown")
+ except CancelledError:
+ self.assertEqual(
+ current_context(),
+ c1,
+ "CancelledError was not raised with the correct logcontext",
+ )
+ # suppress the error and succeed
+
+ d = defer.ensureDeferred(do_lookup())
+ d.cancel()
+
+ complete_lookup.callback(None)
+ self.successResultOf(d)
+ self.assertFalse(
+ obj.inner_context_was_finished, "Tried to restart a finished logcontext"
+ )
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
"""More tests for @cached
@@ -656,7 +802,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1")
+ @descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2):
assert current_context().name == "c1"
# we want this to behave like an asynchronous function
@@ -715,7 +861,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1):
pass
- @descriptors.cachedList("fn", "args1")
+ @descriptors.cachedList(cached_method_name="fn", list_name="args1")
def list_fn(self, args1) -> "Deferred[dict]":
return self.mock(args1)
@@ -758,7 +904,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1")
+ @descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
await run_on_reactor()
@@ -787,3 +933,78 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj.fn.invalidate((10, 2))
invalidate0.assert_called_once()
invalidate1.assert_called_once()
+
+ def test_cancel(self):
+ """Test that cancelling a lookup does not cancel other lookups"""
+ complete_lookup: "Deferred[None]" = Deferred()
+
+ class Cls:
+ @cached()
+ def fn(self, arg1):
+ pass
+
+ @cachedList(cached_method_name="fn", list_name="args")
+ async def list_fn(self, args):
+ await complete_lookup
+ return {arg: str(arg) for arg in args}
+
+ obj = Cls()
+
+ d1 = obj.list_fn([123, 456])
+ d2 = obj.list_fn([123, 456, 789])
+ self.assertFalse(d1.called)
+ self.assertFalse(d2.called)
+
+ d1.cancel()
+
+ # `d2` should complete normally.
+ complete_lookup.callback(None)
+ self.failureResultOf(d1, CancelledError)
+ self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"})
+
+ def test_cancel_logcontexts(self):
+ """Test that cancellation does not break logcontexts.
+
+ * The `CancelledError` must be raised with the correct logcontext.
+ * The inner lookup must not resume with a finished logcontext.
+ * The inner lookup must not restore a finished logcontext when done.
+ """
+ complete_lookup: "Deferred[None]" = Deferred()
+
+ class Cls:
+ inner_context_was_finished = False
+
+ @cached()
+ def fn(self, arg1):
+ pass
+
+ @cachedList(cached_method_name="fn", list_name="args")
+ async def list_fn(self, args):
+ await make_deferred_yieldable(complete_lookup)
+ self.inner_context_was_finished = current_context().finished
+ return {arg: str(arg) for arg in args}
+
+ obj = Cls()
+
+ async def do_lookup():
+ with LoggingContext("c1") as c1:
+ try:
+ await obj.list_fn([123])
+ self.fail("No CancelledError thrown")
+ except CancelledError:
+ self.assertEqual(
+ current_context(),
+ c1,
+ "CancelledError was not raised with the correct logcontext",
+ )
+ # suppress the error and succeed
+
+ d = defer.ensureDeferred(do_lookup())
+ d.cancel()
+
+ complete_lookup.callback(None)
+ self.successResultOf(d)
+ self.assertFalse(
+ obj.inner_context_was_finished, "Tried to restart a finished logcontext"
+ )
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index 362014f4cb..e5bc416de1 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -13,6 +13,8 @@
# limitations under the License.
import traceback
+from parameterized import parameterized_class
+
from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
from twisted.internet.task import Clock
@@ -23,10 +25,12 @@ from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
current_context,
+ make_deferred_yieldable,
)
from synapse.util.async_helpers import (
ObservableDeferred,
concurrently_execute,
+ delay_cancellation,
stop_cancellation,
timeout_deferred,
)
@@ -100,6 +104,34 @@ class ObservableDeferredTest(TestCase):
self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
+ def test_cancellation(self):
+ """Test that cancelling an observer does not affect other observers."""
+ origin_d: "Deferred[int]" = Deferred()
+ observable = ObservableDeferred(origin_d, consumeErrors=True)
+
+ observer1 = observable.observe()
+ observer2 = observable.observe()
+ observer3 = observable.observe()
+
+ self.assertFalse(observer1.called)
+ self.assertFalse(observer2.called)
+ self.assertFalse(observer3.called)
+
+ # cancel the second observer
+ observer2.cancel()
+ self.assertFalse(observer1.called)
+ self.failureResultOf(observer2, CancelledError)
+ self.assertFalse(observer3.called)
+
+ # other observers resolve as normal
+ origin_d.callback(123)
+ self.assertEqual(observer1.result, 123, "observer 1 callback result")
+ self.assertEqual(observer3.result, 123, "observer 3 callback result")
+
+ # additional observers resolve as normal
+ observer4 = observable.observe()
+ self.assertEqual(observer4.result, 123, "observer 4 callback result")
+
class TimeoutDeferredTest(TestCase):
def setUp(self):
@@ -285,13 +317,27 @@ class ConcurrentlyExecuteTest(TestCase):
self.successResultOf(d2)
-class StopCancellationTests(TestCase):
- """Tests for the `stop_cancellation` function."""
+@parameterized_class(
+ ("wrapper",),
+ [("stop_cancellation",), ("delay_cancellation",)],
+)
+class CancellationWrapperTests(TestCase):
+ """Common tests for the `stop_cancellation` and `delay_cancellation` functions."""
+
+ wrapper: str
+
+ def wrap_deferred(self, deferred: "Deferred[str]") -> "Deferred[str]":
+ if self.wrapper == "stop_cancellation":
+ return stop_cancellation(deferred)
+ elif self.wrapper == "delay_cancellation":
+ return delay_cancellation(deferred)
+ else:
+ raise ValueError(f"Unsupported wrapper type: {self.wrapper}")
def test_succeed(self):
"""Test that the new `Deferred` receives the result."""
deferred: "Deferred[str]" = Deferred()
- wrapper_deferred = stop_cancellation(deferred)
+ wrapper_deferred = self.wrap_deferred(deferred)
# Success should propagate through.
deferred.callback("success")
@@ -301,7 +347,7 @@ class StopCancellationTests(TestCase):
def test_failure(self):
"""Test that the new `Deferred` receives the `Failure`."""
deferred: "Deferred[str]" = Deferred()
- wrapper_deferred = stop_cancellation(deferred)
+ wrapper_deferred = self.wrap_deferred(deferred)
# Failure should propagate through.
deferred.errback(ValueError("abc"))
@@ -309,6 +355,10 @@ class StopCancellationTests(TestCase):
self.failureResultOf(wrapper_deferred, ValueError)
self.assertIsNone(deferred.result, "`Failure` was not consumed")
+
+class StopCancellationTests(TestCase):
+ """Tests for the `stop_cancellation` function."""
+
def test_cancellation(self):
"""Test that cancellation of the new `Deferred` leaves the original running."""
deferred: "Deferred[str]" = Deferred()
@@ -319,11 +369,101 @@ class StopCancellationTests(TestCase):
self.assertTrue(wrapper_deferred.called)
self.failureResultOf(wrapper_deferred, CancelledError)
self.assertFalse(
- deferred.called, "Original `Deferred` was unexpectedly cancelled."
+ deferred.called, "Original `Deferred` was unexpectedly cancelled"
+ )
+
+ # Now make the original `Deferred` fail.
+ # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
+ # in logs.
+ deferred.errback(ValueError("abc"))
+ self.assertIsNone(deferred.result, "`Failure` was not consumed")
+
+
+class DelayCancellationTests(TestCase):
+ """Tests for the `delay_cancellation` function."""
+
+ def test_cancellation(self):
+ """Test that cancellation of the new `Deferred` waits for the original."""
+ deferred: "Deferred[str]" = Deferred()
+ wrapper_deferred = delay_cancellation(deferred)
+
+ # Cancel the new `Deferred`.
+ wrapper_deferred.cancel()
+ self.assertNoResult(wrapper_deferred)
+ self.assertFalse(
+ deferred.called, "Original `Deferred` was unexpectedly cancelled"
+ )
+
+ # Now make the original `Deferred` fail.
+ # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
+ # in logs.
+ deferred.errback(ValueError("abc"))
+ self.assertIsNone(deferred.result, "`Failure` was not consumed")
+
+ # Now that the original `Deferred` has failed, we should get a `CancelledError`.
+ self.failureResultOf(wrapper_deferred, CancelledError)
+
+ def test_suppresses_second_cancellation(self):
+ """Test that a second cancellation is suppressed.
+
+ Identical to `test_cancellation` except the new `Deferred` is cancelled twice.
+ """
+ deferred: "Deferred[str]" = Deferred()
+ wrapper_deferred = delay_cancellation(deferred)
+
+ # Cancel the new `Deferred`, twice.
+ wrapper_deferred.cancel()
+ wrapper_deferred.cancel()
+ self.assertNoResult(wrapper_deferred)
+ self.assertFalse(
+ deferred.called, "Original `Deferred` was unexpectedly cancelled"
)
- # Now make the inner `Deferred` fail.
+ # Now make the original `Deferred` fail.
# The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
# in logs.
deferred.errback(ValueError("abc"))
self.assertIsNone(deferred.result, "`Failure` was not consumed")
+
+ # Now that the original `Deferred` has failed, we should get a `CancelledError`.
+ self.failureResultOf(wrapper_deferred, CancelledError)
+
+ def test_propagates_cancelled_error(self):
+ """Test that a `CancelledError` from the original `Deferred` gets propagated."""
+ deferred: "Deferred[str]" = Deferred()
+ wrapper_deferred = delay_cancellation(deferred)
+
+ # Fail the original `Deferred` with a `CancelledError`.
+ cancelled_error = CancelledError()
+ deferred.errback(cancelled_error)
+
+ # The new `Deferred` should fail with exactly the same `CancelledError`.
+ self.assertTrue(wrapper_deferred.called)
+ self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value)
+
+ def test_preserves_logcontext(self):
+ """Test that logging contexts are preserved."""
+ blocking_d: "Deferred[None]" = Deferred()
+
+ async def inner():
+ await make_deferred_yieldable(blocking_d)
+
+ async def outer():
+ with LoggingContext("c") as c:
+ try:
+ await delay_cancellation(defer.ensureDeferred(inner()))
+ self.fail("`CancelledError` was not raised")
+ except CancelledError:
+ self.assertEqual(c, current_context())
+ # Succeed with no error, unless the logging context is wrong.
+
+ # Run and block inside `inner()`.
+ d = defer.ensureDeferred(outer())
+ self.assertEqual(SENTINEL_CONTEXT, current_context())
+
+ d.cancel()
+
+ # Now unblock. `outer()` will consume the `CancelledError` and check the
+ # logging context.
+ blocking_d.callback(None)
+ self.successResultOf(d)
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index 3c07252252..5d1aa025d1 100644
--- a/tests/util/test_check_dependencies.py
+++ b/tests/util/test_check_dependencies.py
@@ -12,7 +12,7 @@ from tests.unittest import TestCase
class DummyDistribution(metadata.Distribution):
- def __init__(self, version: str):
+ def __init__(self, version: object):
self._version = version
@property
@@ -27,7 +27,10 @@ class DummyDistribution(metadata.Distribution):
old = DummyDistribution("0.1.2")
+old_release_candidate = DummyDistribution("0.1.2rc3")
new = DummyDistribution("1.2.3")
+new_release_candidate = DummyDistribution("1.2.3rc4")
+distribution_with_no_version = DummyDistribution(None)
# could probably use stdlib TestCase --- no need for twisted here
@@ -65,6 +68,35 @@ class TestDependencyChecker(TestCase):
# should not raise
check_requirements()
+ def test_version_reported_as_none(self) -> None:
+ """Complain if importlib.metadata.version() returns None.
+
+ This shouldn't normally happen, but it was seen in the wild (#12223).
+ """
+ with patch(
+ "synapse.util.check_dependencies.metadata.requires",
+ return_value=["dummypkg >= 1"],
+ ):
+ with self.mock_installed_package(distribution_with_no_version):
+ self.assertRaises(DependencyException, check_requirements)
+
+ def test_checks_ignore_dev_dependencies(self) -> None:
+ """Bot generic and per-extra checks should ignore dev dependencies."""
+ with patch(
+ "synapse.util.check_dependencies.metadata.requires",
+ return_value=["dummypkg >= 1; extra == 'mypy'"],
+ ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}):
+ # We're testing that none of these calls raise.
+ with self.mock_installed_package(None):
+ check_requirements()
+ check_requirements("cool-extra")
+ with self.mock_installed_package(old):
+ check_requirements()
+ check_requirements("cool-extra")
+ with self.mock_installed_package(new):
+ check_requirements()
+ check_requirements("cool-extra")
+
def test_generic_check_of_optional_dependency(self) -> None:
"""Complain if an optional package is old."""
with patch(
@@ -85,11 +117,28 @@ class TestDependencyChecker(TestCase):
with patch(
"synapse.util.check_dependencies.metadata.requires",
return_value=["dummypkg >= 1; extra == 'cool-extra'"],
- ), patch("synapse.util.check_dependencies.EXTRAS", {"cool-extra"}):
+ ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}):
with self.mock_installed_package(None):
self.assertRaises(DependencyException, check_requirements, "cool-extra")
with self.mock_installed_package(old):
self.assertRaises(DependencyException, check_requirements, "cool-extra")
with self.mock_installed_package(new):
# should not raise
+ check_requirements("cool-extra")
+
+ def test_release_candidates_satisfy_dependency(self) -> None:
+ """
+ Tests that release candidates count as far as satisfying a dependency
+ is concerned.
+ (Regression test, see #12176.)
+ """
+ with patch(
+ "synapse.util.check_dependencies.metadata.requires",
+ return_value=["dummypkg >= 1"],
+ ):
+ with self.mock_installed_package(old_release_candidate):
+ self.assertRaises(DependencyException, check_requirements)
+
+ with self.mock_installed_package(new_release_candidate):
+ # should not raise
check_requirements()
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index 0774625b85..0c84226197 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import AsyncContextManager, Callable, Sequence, Tuple
+
from twisted.internet import defer
-from twisted.internet.defer import Deferred
+from twisted.internet.defer import CancelledError, Deferred
from synapse.util.async_helpers import ReadWriteLock
@@ -21,87 +23,187 @@ from tests import unittest
class ReadWriteLockTestCase(unittest.TestCase):
- def _assert_called_before_not_after(self, lst, first_false):
- for i, d in enumerate(lst[:first_false]):
- self.assertTrue(d.called, msg="%d was unexpectedly false" % i)
+ def _start_reader_or_writer(
+ self,
+ read_or_write: Callable[[str], AsyncContextManager],
+ key: str,
+ return_value: str,
+ ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]:
+ """Starts a reader or writer which acquires the lock, blocks, then completes.
+
+ Args:
+ read_or_write: A function returning a context manager for a lock.
+ Either a bound `ReadWriteLock.read` or `ReadWriteLock.write`.
+ key: The key to read or write.
+ return_value: A string that the reader or writer will resolve with when
+ done.
+
+ Returns:
+ A tuple of three `Deferred`s:
+ * A `Deferred` that resolves with `return_value` once the reader or writer
+ completes successfully.
+ * A `Deferred` that resolves once the reader or writer acquires the lock.
+ * A `Deferred` that blocks the reader or writer. Must be resolved by the
+ caller to allow the reader or writer to release the lock and complete.
+ """
+ acquired_d: "Deferred[None]" = Deferred()
+ unblock_d: "Deferred[None]" = Deferred()
+
+ async def reader_or_writer():
+ async with read_or_write(key):
+ acquired_d.callback(None)
+ await unblock_d
+ return return_value
+
+ d = defer.ensureDeferred(reader_or_writer())
+ return d, acquired_d, unblock_d
+
+ def _start_blocking_reader(
+ self, rwlock: ReadWriteLock, key: str, return_value: str
+ ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]:
+ """Starts a reader which acquires the lock, blocks, then releases the lock.
+
+ See the docstring for `_start_reader_or_writer` for details about the arguments
+ and return values.
+ """
+ return self._start_reader_or_writer(rwlock.read, key, return_value)
+
+ def _start_blocking_writer(
+ self, rwlock: ReadWriteLock, key: str, return_value: str
+ ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]:
+ """Starts a writer which acquires the lock, blocks, then releases the lock.
+
+ See the docstring for `_start_reader_or_writer` for details about the arguments
+ and return values.
+ """
+ return self._start_reader_or_writer(rwlock.write, key, return_value)
+
+ def _start_nonblocking_reader(
+ self, rwlock: ReadWriteLock, key: str, return_value: str
+ ) -> Tuple["Deferred[str]", "Deferred[None]"]:
+ """Starts a reader which acquires the lock, then releases it immediately.
+
+ See the docstring for `_start_reader_or_writer` for details about the arguments.
+
+ Returns:
+ A tuple of two `Deferred`s:
+ * A `Deferred` that resolves with `return_value` once the reader completes
+ successfully.
+ * A `Deferred` that resolves once the reader acquires the lock.
+ """
+ d, acquired_d, unblock_d = self._start_reader_or_writer(
+ rwlock.read, key, return_value
+ )
+ unblock_d.callback(None)
+ return d, acquired_d
+
+ def _start_nonblocking_writer(
+ self, rwlock: ReadWriteLock, key: str, return_value: str
+ ) -> Tuple["Deferred[str]", "Deferred[None]"]:
+ """Starts a writer which acquires the lock, then releases it immediately.
+
+ See the docstring for `_start_reader_or_writer` for details about the arguments.
+
+ Returns:
+ A tuple of two `Deferred`s:
+ * A `Deferred` that resolves with `return_value` once the writer completes
+ successfully.
+ * A `Deferred` that resolves once the writer acquires the lock.
+ """
+ d, acquired_d, unblock_d = self._start_reader_or_writer(
+ rwlock.write, key, return_value
+ )
+ unblock_d.callback(None)
+ return d, acquired_d
+
+ def _assert_first_n_resolved(
+ self, deferreds: Sequence["defer.Deferred[None]"], n: int
+ ) -> None:
+ """Assert that exactly the first n `Deferred`s in the given list are resolved.
- for i, d in enumerate(lst[first_false:]):
+ Args:
+ deferreds: The list of `Deferred`s to be checked.
+ n: The number of `Deferred`s at the start of `deferreds` that should be
+ resolved.
+ """
+ for i, d in enumerate(deferreds[:n]):
+ self.assertTrue(d.called, msg="deferred %d was unexpectedly unresolved" % i)
+
+ for i, d in enumerate(deferreds[n:]):
self.assertFalse(
- d.called, msg="%d was unexpectedly true" % (i + first_false)
+ d.called, msg="deferred %d was unexpectedly resolved" % (i + n)
)
def test_rwlock(self):
rwlock = ReadWriteLock()
-
- key = object()
+ key = "key"
ds = [
- rwlock.read(key), # 0
- rwlock.read(key), # 1
- rwlock.write(key), # 2
- rwlock.write(key), # 3
- rwlock.read(key), # 4
- rwlock.read(key), # 5
- rwlock.write(key), # 6
+ self._start_blocking_reader(rwlock, key, "0"),
+ self._start_blocking_reader(rwlock, key, "1"),
+ self._start_blocking_writer(rwlock, key, "2"),
+ self._start_blocking_writer(rwlock, key, "3"),
+ self._start_blocking_reader(rwlock, key, "4"),
+ self._start_blocking_reader(rwlock, key, "5"),
+ self._start_blocking_writer(rwlock, key, "6"),
]
- ds = [defer.ensureDeferred(d) for d in ds]
+ # `Deferred`s that resolve when each reader or writer acquires the lock.
+ acquired_ds = [acquired_d for _, acquired_d, _ in ds]
+ # `Deferred`s that will trigger the release of locks when resolved.
+ release_ds = [release_d for _, _, release_d in ds]
- self._assert_called_before_not_after(ds, 2)
+ # The first two readers should acquire their locks.
+ self._assert_first_n_resolved(acquired_ds, 2)
- with ds[0].result:
- self._assert_called_before_not_after(ds, 2)
- self._assert_called_before_not_after(ds, 2)
+ # Release one of the read locks. The next writer should not acquire the lock,
+ # because there is another reader holding the lock.
+ self._assert_first_n_resolved(acquired_ds, 2)
+ release_ds[0].callback(None)
+ self._assert_first_n_resolved(acquired_ds, 2)
- with ds[1].result:
- self._assert_called_before_not_after(ds, 2)
- self._assert_called_before_not_after(ds, 3)
+ # Release the other read lock. The next writer should acquire the lock.
+ self._assert_first_n_resolved(acquired_ds, 2)
+ release_ds[1].callback(None)
+ self._assert_first_n_resolved(acquired_ds, 3)
- with ds[2].result:
- self._assert_called_before_not_after(ds, 3)
- self._assert_called_before_not_after(ds, 4)
+ # Release the write lock. The next writer should acquire the lock.
+ self._assert_first_n_resolved(acquired_ds, 3)
+ release_ds[2].callback(None)
+ self._assert_first_n_resolved(acquired_ds, 4)
- with ds[3].result:
- self._assert_called_before_not_after(ds, 4)
- self._assert_called_before_not_after(ds, 6)
+ # Release the write lock. The next two readers should acquire locks.
+ self._assert_first_n_resolved(acquired_ds, 4)
+ release_ds[3].callback(None)
+ self._assert_first_n_resolved(acquired_ds, 6)
- with ds[5].result:
- self._assert_called_before_not_after(ds, 6)
- self._assert_called_before_not_after(ds, 6)
+ # Release one of the read locks. The next writer should not acquire the lock,
+ # because there is another reader holding the lock.
+ self._assert_first_n_resolved(acquired_ds, 6)
+ release_ds[5].callback(None)
+ self._assert_first_n_resolved(acquired_ds, 6)
- with ds[4].result:
- self._assert_called_before_not_after(ds, 6)
- self._assert_called_before_not_after(ds, 7)
+ # Release the other read lock. The next writer should acquire the lock.
+ self._assert_first_n_resolved(acquired_ds, 6)
+ release_ds[4].callback(None)
+ self._assert_first_n_resolved(acquired_ds, 7)
- with ds[6].result:
- pass
+ # Release the write lock.
+ release_ds[6].callback(None)
- d = defer.ensureDeferred(rwlock.write(key))
- self.assertTrue(d.called)
- with d.result:
- pass
+ # Acquire and release the write and read locks one last time for good measure.
+ _, acquired_d = self._start_nonblocking_writer(rwlock, key, "last writer")
+ self.assertTrue(acquired_d.called)
- d = defer.ensureDeferred(rwlock.read(key))
- self.assertTrue(d.called)
- with d.result:
- pass
+ _, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader")
+ self.assertTrue(acquired_d.called)
def test_lock_handoff_to_nonblocking_writer(self):
"""Test a writer handing the lock to another writer that completes instantly."""
rwlock = ReadWriteLock()
key = "key"
- unblock: "Deferred[None]" = Deferred()
-
- async def blocking_write():
- with await rwlock.write(key):
- await unblock
-
- async def nonblocking_write():
- with await rwlock.write(key):
- pass
-
- d1 = defer.ensureDeferred(blocking_write())
- d2 = defer.ensureDeferred(nonblocking_write())
+ d1, _, unblock = self._start_blocking_writer(rwlock, key, "write 1 completed")
+ d2, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed")
self.assertFalse(d1.called)
self.assertFalse(d2.called)
@@ -111,5 +213,182 @@ class ReadWriteLockTestCase(unittest.TestCase):
self.assertTrue(d2.called)
# The `ReadWriteLock` should operate as normal.
- d3 = defer.ensureDeferred(nonblocking_write())
+ d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
self.assertTrue(d3.called)
+
+ def test_cancellation_while_holding_read_lock(self):
+ """Test cancellation while holding a read lock.
+
+ A waiting writer should be given the lock when the reader holding the lock is
+ cancelled.
+ """
+ rwlock = ReadWriteLock()
+ key = "key"
+
+ # 1. A reader takes the lock and blocks.
+ reader_d, _, _ = self._start_blocking_reader(rwlock, key, "read completed")
+
+ # 2. A writer waits for the reader to complete.
+ writer_d, _ = self._start_nonblocking_writer(rwlock, key, "write completed")
+ self.assertFalse(writer_d.called)
+
+ # 3. The reader is cancelled.
+ reader_d.cancel()
+ self.failureResultOf(reader_d, CancelledError)
+
+ # 4. The writer should take the lock and complete.
+ self.assertTrue(
+ writer_d.called, "Writer is stuck waiting for a cancelled reader"
+ )
+ self.assertEqual("write completed", self.successResultOf(writer_d))
+
+ def test_cancellation_while_holding_write_lock(self):
+ """Test cancellation while holding a write lock.
+
+ A waiting reader should be given the lock when the writer holding the lock is
+ cancelled.
+ """
+ rwlock = ReadWriteLock()
+ key = "key"
+
+ # 1. A writer takes the lock and blocks.
+ writer_d, _, _ = self._start_blocking_writer(rwlock, key, "write completed")
+
+ # 2. A reader waits for the writer to complete.
+ reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed")
+ self.assertFalse(reader_d.called)
+
+ # 3. The writer is cancelled.
+ writer_d.cancel()
+ self.failureResultOf(writer_d, CancelledError)
+
+ # 4. The reader should take the lock and complete.
+ self.assertTrue(
+ reader_d.called, "Reader is stuck waiting for a cancelled writer"
+ )
+ self.assertEqual("read completed", self.successResultOf(reader_d))
+
+ def test_cancellation_while_waiting_for_read_lock(self):
+ """Test cancellation while waiting for a read lock.
+
+ Tests that cancelling a waiting reader:
+ * does not cancel the writer it is waiting on
+ * does not cancel the next writer waiting on it
+ * does not allow the next writer to acquire the lock before an earlier writer
+ has finished
+ * does not keep the next writer waiting indefinitely
+
+ These correspond to the asserts with explicit messages.
+ """
+ rwlock = ReadWriteLock()
+ key = "key"
+
+ # 1. A writer takes the lock and blocks.
+ writer1_d, _, unblock_writer1 = self._start_blocking_writer(
+ rwlock, key, "write 1 completed"
+ )
+
+ # 2. A reader waits for the first writer to complete.
+ # This reader will be cancelled later.
+ reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed")
+ self.assertFalse(reader_d.called)
+
+ # 3. A second writer waits for both the first writer and the reader to complete.
+ writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed")
+ self.assertFalse(writer2_d.called)
+
+ # 4. The waiting reader is cancelled.
+ # Neither of the writers should be cancelled.
+ # The second writer should still be waiting, but only on the first writer.
+ reader_d.cancel()
+ self.failureResultOf(reader_d, CancelledError)
+ self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled")
+ self.assertFalse(
+ writer2_d.called,
+ "Second writer was unexpectedly cancelled or given the lock before the "
+ "first writer finished",
+ )
+
+ # 5. Unblock the first writer, which should complete.
+ unblock_writer1.callback(None)
+ self.assertEqual("write 1 completed", self.successResultOf(writer1_d))
+
+ # 6. The second writer should take the lock and complete.
+ self.assertTrue(
+ writer2_d.called, "Second writer is stuck waiting for a cancelled reader"
+ )
+ self.assertEqual("write 2 completed", self.successResultOf(writer2_d))
+
+ def test_cancellation_while_waiting_for_write_lock(self):
+ """Test cancellation while waiting for a write lock.
+
+ Tests that cancelling a waiting writer:
+ * does not cancel the reader or writer it is waiting on
+ * does not cancel the next writer waiting on it
+ * does not allow the next writer to acquire the lock before an earlier reader
+ and writer have finished
+ * does not keep the next writer waiting indefinitely
+
+ These correspond to the asserts with explicit messages.
+ """
+ rwlock = ReadWriteLock()
+ key = "key"
+
+ # 1. A reader takes the lock and blocks.
+ reader_d, _, unblock_reader = self._start_blocking_reader(
+ rwlock, key, "read completed"
+ )
+
+ # 2. A writer waits for the reader to complete.
+ writer1_d, _, unblock_writer1 = self._start_blocking_writer(
+ rwlock, key, "write 1 completed"
+ )
+
+ # 3. A second writer waits for both the reader and first writer to complete.
+ # This writer will be cancelled later.
+ writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed")
+ self.assertFalse(writer2_d.called)
+
+ # 4. A third writer waits for the second writer to complete.
+ writer3_d, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
+ self.assertFalse(writer3_d.called)
+
+ # 5. The second writer is cancelled, but continues waiting for the lock.
+ # The reader, first writer and third writer should not be cancelled.
+ # The first writer should still be waiting on the reader.
+ # The third writer should still be waiting on the second writer.
+ writer2_d.cancel()
+ self.assertNoResult(writer2_d)
+ self.assertFalse(reader_d.called, "Reader was unexpectedly cancelled")
+ self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled")
+ self.assertFalse(
+ writer3_d.called,
+ "Third writer was unexpectedly cancelled or given the lock before the first "
+ "writer finished",
+ )
+
+ # 6. Unblock the reader, which should complete.
+ # The first writer should be given the lock and block.
+ # The third writer should still be waiting on the second writer.
+ unblock_reader.callback(None)
+ self.assertEqual("read completed", self.successResultOf(reader_d))
+ self.assertNoResult(writer2_d)
+ self.assertFalse(
+ writer3_d.called,
+ "Third writer was unexpectedly given the lock before the first writer "
+ "finished",
+ )
+
+ # 7. Unblock the first writer, which should complete.
+ unblock_writer1.callback(None)
+ self.assertEqual("write 1 completed", self.successResultOf(writer1_d))
+
+ # 8. The second writer should take the lock and release it immediately, since it
+ # has been cancelled.
+ self.failureResultOf(writer2_d, CancelledError)
+
+ # 9. The third writer should take the lock and complete.
+ self.assertTrue(
+ writer3_d.called, "Third writer is stuck waiting for a cancelled writer"
+ )
+ self.assertEqual("write 3 completed", self.successResultOf(writer3_d))
|