diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 1a8c99d387..861c24809c 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -528,7 +528,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
observable = ObservableDeferred(deferred)
- cache.set(key, observable)
+ cache.set(key, observable, callback=invalidate_callback)
def complete_all(res):
# the wrapped function has completed. It returns a
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 1ac967b63e..ca8a7c907f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -275,7 +275,6 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.assert_not_called()
-@unittest.DEBUG
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
@@ -334,3 +333,44 @@ class CachedListDescriptorTestCase(unittest.TestCase):
r = yield obj.list_fn([10, 20, 30], 2)
obj.mock.assert_not_called()
self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
+
+ @defer.inlineCallbacks
+ def test_invalidate(self):
+ """Make sure that invalidation callbacks are called."""
+ class Cls(object):
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached()
+ def fn(self, arg1, arg2):
+ pass
+
+ @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+ def list_fn(self, args1, arg2):
+ # we want this to behave like an asynchronous function
+ yield run_on_reactor()
+ defer.returnValue(self.mock(args1, arg2))
+
+ obj = Cls()
+ invalidate0 = mock.Mock()
+ invalidate1 = mock.Mock()
+
+ # cache miss
+ obj.mock.return_value = {10: 'fish', 20: 'chips'}
+ r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
+ obj.mock.assert_called_once_with([10, 20], 2)
+ self.assertEqual(r1, {10: 'fish', 20: 'chips'})
+ obj.mock.reset_mock()
+
+ # cache hit
+ r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
+ obj.mock.assert_not_called()
+ self.assertEqual(r2, {10: 'fish', 20: 'chips'})
+
+ invalidate0.assert_not_called()
+ invalidate1.assert_not_called()
+
+ # now if we invalidate the keys, both invalidations should get called
+ obj.fn.invalidate((10, 2))
+ invalidate0.assert_called_once()
+ invalidate1.assert_called_once()
|