diff options
Diffstat (limited to 'tests/util/caches/test_descriptors.py')
-rw-r--r-- | tests/util/caches/test_descriptors.py | 151 |
1 files changed, 130 insertions, 21 deletions
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 2516fe40f4..463a737efa 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -13,20 +13,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial import logging +from functools import partial import mock + +from twisted.internet import defer, reactor + from synapse.api.errors import SynapseError -from synapse.util import async from synapse.util import logcontext -from twisted.internet import defer from synapse.util.caches import descriptors + from tests import unittest logger = logging.getLogger(__name__) +def run_on_reactor(): + d = defer.Deferred() + reactor.callLater(0, d.callback, 0) + return logcontext.make_deferred_yieldable(d) + + class CacheTestCase(unittest.TestCase): def test_invalidate_all(self): cache = descriptors.Cache("testcache") @@ -59,12 +67,8 @@ class CacheTestCase(unittest.TestCase): self.assertIsNone(cache.get("key2", None)) # both callbacks should have been callbacked - self.assertTrue( - callback_record[0], "Invalidation callback for key1 not called", - ) - self.assertTrue( - callback_record[1], "Invalidation callback for key2 not called", - ) + self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") + self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") # letting the other lookup complete should do nothing d1.callback("result1") @@ -160,8 +164,7 @@ class DescriptorTestCase(unittest.TestCase): with logcontext.LoggingContext() as c1: c1.name = "c1" r = yield obj.fn(1) - self.assertEqual(logcontext.LoggingContext.current_context(), - c1) + self.assertEqual(logcontext.LoggingContext.current_context(), c1) defer.returnValue(r) def check_result(r): @@ -171,14 +174,18 @@ class DescriptorTestCase(unittest.TestCase): # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(logcontext.LoggingContext.current_context(), - logcontext.LoggingContext.sentinel) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) d1.addCallback(check_result) # and another d2 = do_lookup() - self.assertEqual(logcontext.LoggingContext.current_context(), - logcontext.LoggingContext.sentinel) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) d2.addCallback(check_result) # let the lookup complete @@ -195,7 +202,8 @@ class DescriptorTestCase(unittest.TestCase): def fn(self, arg1): @defer.inlineCallbacks def inner_fn(): - yield async.run_on_reactor() + # we want this to behave like an asynchronous function + yield run_on_reactor() raise SynapseError(400, "blah") return inner_fn() @@ -205,20 +213,26 @@ class DescriptorTestCase(unittest.TestCase): with logcontext.LoggingContext() as c1: c1.name = "c1" try: - yield obj.fn(1) + d = obj.fn(1) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) + yield d self.fail("No exception thrown") except SynapseError: pass - self.assertEqual(logcontext.LoggingContext.current_context(), - c1) + self.assertEqual(logcontext.LoggingContext.current_context(), c1) obj = Cls() # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(logcontext.LoggingContext.current_context(), - logcontext.LoggingContext.sentinel) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) return d1 @@ -259,3 +273,98 @@ class DescriptorTestCase(unittest.TestCase): r = yield obj.fn(2, 3) self.assertEqual(r, 'chips') obj.mock.assert_not_called() + + +class CachedListDescriptorTestCase(unittest.TestCase): + @defer.inlineCallbacks + def test_cache(self): + class Cls(object): + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, arg2): + pass + + @descriptors.cachedList("fn", "args1", inlineCallbacks=True) + def list_fn(self, args1, arg2): + assert logcontext.LoggingContext.current_context().request == "c1" + # we want this to behave like an asynchronous function + yield run_on_reactor() + assert logcontext.LoggingContext.current_context().request == "c1" + defer.returnValue(self.mock(args1, arg2)) + + with logcontext.LoggingContext() as c1: + c1.request = "c1" + obj = Cls() + obj.mock.return_value = {10: 'fish', 20: 'chips'} + d1 = obj.list_fn([10, 20], 2) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) + r = yield d1 + self.assertEqual(logcontext.LoggingContext.current_context(), c1) + obj.mock.assert_called_once_with([10, 20], 2) + self.assertEqual(r, {10: 'fish', 20: 'chips'}) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = {30: 'peas'} + r = yield obj.list_fn([20, 30], 2) + obj.mock.assert_called_once_with([30], 2) + self.assertEqual(r, {20: 'chips', 30: 'peas'}) + obj.mock.reset_mock() + + # all the values should now be cached + r = yield obj.fn(10, 2) + self.assertEqual(r, 'fish') + r = yield obj.fn(20, 2) + self.assertEqual(r, 'chips') + r = yield obj.fn(30, 2) + self.assertEqual(r, 'peas') + r = yield obj.list_fn([10, 20, 30], 2) + obj.mock.assert_not_called() + self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'}) + + @defer.inlineCallbacks + def test_invalidate(self): + """Make sure that invalidation callbacks are called.""" + + class Cls(object): + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, arg2): + pass + + @descriptors.cachedList("fn", "args1", inlineCallbacks=True) + def list_fn(self, args1, arg2): + # we want this to behave like an asynchronous function + yield run_on_reactor() + defer.returnValue(self.mock(args1, arg2)) + + obj = Cls() + invalidate0 = mock.Mock() + invalidate1 = mock.Mock() + + # cache miss + obj.mock.return_value = {10: 'fish', 20: 'chips'} + r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0) + obj.mock.assert_called_once_with([10, 20], 2) + self.assertEqual(r1, {10: 'fish', 20: 'chips'}) + obj.mock.reset_mock() + + # cache hit + r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1) + obj.mock.assert_not_called() + self.assertEqual(r2, {10: 'fish', 20: 'chips'}) + + invalidate0.assert_not_called() + invalidate1.assert_not_called() + + # now if we invalidate the keys, both invalidations should get called + obj.fn.invalidate((10, 2)) + invalidate0.assert_called_once() + invalidate1.assert_called_once() |