diff --git a/changelog.d/12183.misc b/changelog.d/12183.misc
new file mode 100644
index 0000000000..dd441bb64f
--- /dev/null
+++ b/changelog.d/12183.misc
@@ -0,0 +1 @@
+Add cancellation support to `@cached` and `@cachedList` decorators.
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index c3c5c16db9..eda92d864d 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -41,6 +41,7 @@ from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import delay_cancellation
from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.lrucache import LruCache
@@ -350,6 +351,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
ret = cache.set(cache_key, ret, callback=invalidate_callback)
+ # We started a new call to `self.orig`, so we must always wait for it to
+ # complete. Otherwise we might mark our current logging context as
+ # finished while `self.orig` is still using it in the background.
+ ret = delay_cancellation(ret)
+
return make_deferred_yieldable(ret)
wrapped = cast(_CachedFunction, _wrapped)
@@ -510,6 +516,11 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
lambda _: results, unwrapFirstError
)
+ if missing:
+ # We started a new call to `self.orig`, so we must always wait for it to
+ # complete. Otherwise we might mark our current logging context as
+ # finished while `self.orig` is still using it in the background.
+ d = delay_cancellation(d)
return make_deferred_yieldable(d)
else:
return defer.succeed(results)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 6a4b17527a..48e616ac74 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -17,7 +17,7 @@ from typing import Set
from unittest import mock
from twisted.internet import defer, reactor
-from twisted.internet.defer import Deferred
+from twisted.internet.defer import CancelledError, Deferred
from synapse.api.errors import SynapseError
from synapse.logging.context import (
@@ -28,7 +28,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached, lru_cache
+from synapse.util.caches.descriptors import cached, cachedList, lru_cache
from tests import unittest
from tests.test_utils import get_awaitable_result
@@ -493,6 +493,74 @@ class DescriptorTestCase(unittest.TestCase):
obj.invalidate()
top_invalidate.assert_called_once()
+ def test_cancel(self):
+ """Test that cancelling a lookup does not cancel other lookups"""
+ complete_lookup: "Deferred[None]" = Deferred()
+
+ class Cls:
+ @cached()
+ async def fn(self, arg1):
+ await complete_lookup
+ return str(arg1)
+
+ obj = Cls()
+
+ d1 = obj.fn(123)
+ d2 = obj.fn(123)
+ self.assertFalse(d1.called)
+ self.assertFalse(d2.called)
+
+ # Cancel `d1`, which is the lookup that caused `fn` to run.
+ d1.cancel()
+
+ # `d2` should complete normally.
+ complete_lookup.callback(None)
+ self.failureResultOf(d1, CancelledError)
+ self.assertEqual(d2.result, "123")
+
+ def test_cancel_logcontexts(self):
+ """Test that cancellation does not break logcontexts.
+
+ * The `CancelledError` must be raised with the correct logcontext.
+ * The inner lookup must not resume with a finished logcontext.
+ * The inner lookup must not restore a finished logcontext when done.
+ """
+ complete_lookup: "Deferred[None]" = Deferred()
+
+ class Cls:
+ inner_context_was_finished = False
+
+ @cached()
+ async def fn(self, arg1):
+ await make_deferred_yieldable(complete_lookup)
+ self.inner_context_was_finished = current_context().finished
+ return str(arg1)
+
+ obj = Cls()
+
+ async def do_lookup():
+ with LoggingContext("c1") as c1:
+ try:
+ await obj.fn(123)
+ self.fail("No CancelledError thrown")
+ except CancelledError:
+ self.assertEqual(
+ current_context(),
+ c1,
+ "CancelledError was not raised with the correct logcontext",
+ )
+ # suppress the error and succeed
+
+ d = defer.ensureDeferred(do_lookup())
+ d.cancel()
+
+ complete_lookup.callback(None)
+ self.successResultOf(d)
+ self.assertFalse(
+ obj.inner_context_was_finished, "Tried to restart a finished logcontext"
+ )
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
"""More tests for @cached
@@ -865,3 +933,78 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj.fn.invalidate((10, 2))
invalidate0.assert_called_once()
invalidate1.assert_called_once()
+
+ def test_cancel(self):
+ """Test that cancelling a lookup does not cancel other lookups"""
+ complete_lookup: "Deferred[None]" = Deferred()
+
+ class Cls:
+ @cached()
+ def fn(self, arg1):
+ pass
+
+ @cachedList(cached_method_name="fn", list_name="args")
+ async def list_fn(self, args):
+ await complete_lookup
+ return {arg: str(arg) for arg in args}
+
+ obj = Cls()
+
+ d1 = obj.list_fn([123, 456])
+ d2 = obj.list_fn([123, 456, 789])
+ self.assertFalse(d1.called)
+ self.assertFalse(d2.called)
+
+ d1.cancel()
+
+ # `d2` should complete normally.
+ complete_lookup.callback(None)
+ self.failureResultOf(d1, CancelledError)
+ self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"})
+
+ def test_cancel_logcontexts(self):
+ """Test that cancellation does not break logcontexts.
+
+ * The `CancelledError` must be raised with the correct logcontext.
+ * The inner lookup must not resume with a finished logcontext.
+ * The inner lookup must not restore a finished logcontext when done.
+ """
+ complete_lookup: "Deferred[None]" = Deferred()
+
+ class Cls:
+ inner_context_was_finished = False
+
+ @cached()
+ def fn(self, arg1):
+ pass
+
+ @cachedList(cached_method_name="fn", list_name="args")
+ async def list_fn(self, args):
+ await make_deferred_yieldable(complete_lookup)
+ self.inner_context_was_finished = current_context().finished
+ return {arg: str(arg) for arg in args}
+
+ obj = Cls()
+
+ async def do_lookup():
+ with LoggingContext("c1") as c1:
+ try:
+ await obj.list_fn([123])
+ self.fail("No CancelledError thrown")
+ except CancelledError:
+ self.assertEqual(
+ current_context(),
+ c1,
+ "CancelledError was not raised with the correct logcontext",
+ )
+ # suppress the error and succeed
+
+ d = defer.ensureDeferred(do_lookup())
+ d.cancel()
+
+ complete_lookup.callback(None)
+ self.successResultOf(d)
+ self.assertFalse(
+ obj.inner_context_was_finished, "Tried to restart a finished logcontext"
+ )
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
|