diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2516fe40f4..463a737efa 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -13,20 +13,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from functools import partial
import logging
+from functools import partial
import mock
+
+from twisted.internet import defer, reactor
+
from synapse.api.errors import SynapseError
-from synapse.util import async
from synapse.util import logcontext
-from twisted.internet import defer
from synapse.util.caches import descriptors
+
from tests import unittest
logger = logging.getLogger(__name__)
+def run_on_reactor():
+ d = defer.Deferred()
+ reactor.callLater(0, d.callback, 0)
+ return logcontext.make_deferred_yieldable(d)
+
+
class CacheTestCase(unittest.TestCase):
def test_invalidate_all(self):
cache = descriptors.Cache("testcache")
@@ -59,12 +67,8 @@ class CacheTestCase(unittest.TestCase):
self.assertIsNone(cache.get("key2", None))
# both callbacks should have been callbacked
- self.assertTrue(
- callback_record[0], "Invalidation callback for key1 not called",
- )
- self.assertTrue(
- callback_record[1], "Invalidation callback for key2 not called",
- )
+ self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
+ self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
# letting the other lookup complete should do nothing
d1.callback("result1")
@@ -160,8 +164,7 @@ class DescriptorTestCase(unittest.TestCase):
with logcontext.LoggingContext() as c1:
c1.name = "c1"
r = yield obj.fn(1)
- self.assertEqual(logcontext.LoggingContext.current_context(),
- c1)
+ self.assertEqual(logcontext.LoggingContext.current_context(), c1)
defer.returnValue(r)
def check_result(r):
@@ -171,14 +174,18 @@ class DescriptorTestCase(unittest.TestCase):
# set off a deferred which will do a cache lookup
d1 = do_lookup()
- self.assertEqual(logcontext.LoggingContext.current_context(),
- logcontext.LoggingContext.sentinel)
+ self.assertEqual(
+ logcontext.LoggingContext.current_context(),
+ logcontext.LoggingContext.sentinel,
+ )
d1.addCallback(check_result)
# and another
d2 = do_lookup()
- self.assertEqual(logcontext.LoggingContext.current_context(),
- logcontext.LoggingContext.sentinel)
+ self.assertEqual(
+ logcontext.LoggingContext.current_context(),
+ logcontext.LoggingContext.sentinel,
+ )
d2.addCallback(check_result)
# let the lookup complete
@@ -195,7 +202,8 @@ class DescriptorTestCase(unittest.TestCase):
def fn(self, arg1):
@defer.inlineCallbacks
def inner_fn():
- yield async.run_on_reactor()
+ # we want this to behave like an asynchronous function
+ yield run_on_reactor()
raise SynapseError(400, "blah")
return inner_fn()
@@ -205,20 +213,26 @@ class DescriptorTestCase(unittest.TestCase):
with logcontext.LoggingContext() as c1:
c1.name = "c1"
try:
- yield obj.fn(1)
+ d = obj.fn(1)
+ self.assertEqual(
+ logcontext.LoggingContext.current_context(),
+ logcontext.LoggingContext.sentinel,
+ )
+ yield d
self.fail("No exception thrown")
except SynapseError:
pass
- self.assertEqual(logcontext.LoggingContext.current_context(),
- c1)
+ self.assertEqual(logcontext.LoggingContext.current_context(), c1)
obj = Cls()
# set off a deferred which will do a cache lookup
d1 = do_lookup()
- self.assertEqual(logcontext.LoggingContext.current_context(),
- logcontext.LoggingContext.sentinel)
+ self.assertEqual(
+ logcontext.LoggingContext.current_context(),
+ logcontext.LoggingContext.sentinel,
+ )
return d1
@@ -259,3 +273,98 @@ class DescriptorTestCase(unittest.TestCase):
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_not_called()
+
+
+class CachedListDescriptorTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
+ def test_cache(self):
+ class Cls(object):
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached()
+ def fn(self, arg1, arg2):
+ pass
+
+ @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+ def list_fn(self, args1, arg2):
+ assert logcontext.LoggingContext.current_context().request == "c1"
+ # we want this to behave like an asynchronous function
+ yield run_on_reactor()
+ assert logcontext.LoggingContext.current_context().request == "c1"
+ defer.returnValue(self.mock(args1, arg2))
+
+ with logcontext.LoggingContext() as c1:
+ c1.request = "c1"
+ obj = Cls()
+ obj.mock.return_value = {10: 'fish', 20: 'chips'}
+ d1 = obj.list_fn([10, 20], 2)
+ self.assertEqual(
+ logcontext.LoggingContext.current_context(),
+ logcontext.LoggingContext.sentinel,
+ )
+ r = yield d1
+ self.assertEqual(logcontext.LoggingContext.current_context(), c1)
+ obj.mock.assert_called_once_with([10, 20], 2)
+ self.assertEqual(r, {10: 'fish', 20: 'chips'})
+ obj.mock.reset_mock()
+
+ # a call with different params should call the mock again
+ obj.mock.return_value = {30: 'peas'}
+ r = yield obj.list_fn([20, 30], 2)
+ obj.mock.assert_called_once_with([30], 2)
+ self.assertEqual(r, {20: 'chips', 30: 'peas'})
+ obj.mock.reset_mock()
+
+ # all the values should now be cached
+ r = yield obj.fn(10, 2)
+ self.assertEqual(r, 'fish')
+ r = yield obj.fn(20, 2)
+ self.assertEqual(r, 'chips')
+ r = yield obj.fn(30, 2)
+ self.assertEqual(r, 'peas')
+ r = yield obj.list_fn([10, 20, 30], 2)
+ obj.mock.assert_not_called()
+ self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
+
+ @defer.inlineCallbacks
+ def test_invalidate(self):
+ """Make sure that invalidation callbacks are called."""
+
+ class Cls(object):
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @descriptors.cached()
+ def fn(self, arg1, arg2):
+ pass
+
+ @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+ def list_fn(self, args1, arg2):
+ # we want this to behave like an asynchronous function
+ yield run_on_reactor()
+ defer.returnValue(self.mock(args1, arg2))
+
+ obj = Cls()
+ invalidate0 = mock.Mock()
+ invalidate1 = mock.Mock()
+
+ # cache miss
+ obj.mock.return_value = {10: 'fish', 20: 'chips'}
+ r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
+ obj.mock.assert_called_once_with([10, 20], 2)
+ self.assertEqual(r1, {10: 'fish', 20: 'chips'})
+ obj.mock.reset_mock()
+
+ # cache hit
+ r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
+ obj.mock.assert_not_called()
+ self.assertEqual(r2, {10: 'fish', 20: 'chips'})
+
+ invalidate0.assert_not_called()
+ invalidate1.assert_not_called()
+
+ # now if we invalidate the keys, both invalidations should get called
+ obj.fn.invalidate((10, 2))
+ invalidate0.assert_called_once()
+ invalidate1.assert_called_once()
|