diff options
-rw-r--r-- | changelog.d/11246.misc | 1 | ||||
-rw-r--r-- | tests/util/caches/test_descriptors.py | 43 |
2 files changed, 44 insertions, 0 deletions
diff --git a/changelog.d/11246.misc b/changelog.d/11246.misc new file mode 100644 index 0000000000..e5e912c1b0 --- /dev/null +++ b/changelog.d/11246.misc @@ -0,0 +1 @@ +Add an additional test for the `cachedList` method decorator. diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 39947a166b..ced3efd93f 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -17,6 +17,7 @@ from typing import Set from unittest import mock from twisted.internet import defer, reactor +from twisted.internet.defer import Deferred from synapse.api.errors import SynapseError from synapse.logging.context import ( @@ -703,6 +704,48 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj.mock.assert_called_once_with((40,), 2) self.assertEqual(r, {10: "fish", 40: "gravy"}) + def test_concurrent_lookups(self): + """All concurrent lookups should get the same result""" + + class Cls: + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1): + pass + + @descriptors.cachedList("fn", "args1") + def list_fn(self, args1) -> "Deferred[dict]": + return self.mock(args1) + + obj = Cls() + deferred_result = Deferred() + obj.mock.return_value = deferred_result + + # start off several concurrent lookups of the same key + d1 = obj.list_fn([10]) + d2 = obj.list_fn([10]) + d3 = obj.list_fn([10]) + + # the mock should have been called exactly once + obj.mock.assert_called_once_with((10,)) + obj.mock.reset_mock() + + # ... and none of the calls should yet be complete + self.assertFalse(d1.called) + self.assertFalse(d2.called) + self.assertFalse(d3.called) + + # complete the lookup. @cachedList functions need to complete with a map + # of input->result + deferred_result.callback({10: "peas"}) + + # ... which should give the right result to all the callers + self.assertEqual(self.successResultOf(d1), {10: "peas"}) + self.assertEqual(self.successResultOf(d2), {10: "peas"}) + self.assertEqual(self.successResultOf(d3), {10: "peas"}) + @defer.inlineCallbacks def test_invalidate(self): """Make sure that invalidation callbacks are called.""" |