summary refs log tree commit diff
path: root/tests/util/caches/test_descriptors.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2018-08-03 09:25:15 +0100
committerErik Johnston <erik@matrix.org>2018-08-03 09:25:15 +0100
commitcb298ff623c1ad375084f7687b7f6e546c4c1c1f (patch)
tree81854fc4ae0a45bed487a58f2c40974b683adb92 /tests/util/caches/test_descriptors.py
parentNewsfile (diff)
parentMerge pull request #3645 from matrix-org/michaelkaye/mention_newsfragment (diff)
downloadsynapse-cb298ff623c1ad375084f7687b7f6e546c4c1c1f.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/refactor_repl_servlet
Diffstat (limited to 'tests/util/caches/test_descriptors.py')
-rw-r--r--tests/util/caches/test_descriptors.py101
1 files changed, 101 insertions, 0 deletions
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 8176a7dabd..ca8a7c907f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -273,3 +273,104 @@ class DescriptorTestCase(unittest.TestCase):
         r = yield obj.fn(2, 3)
         self.assertEqual(r, 'chips')
         obj.mock.assert_not_called()
+
+
+class CachedListDescriptorTestCase(unittest.TestCase):
+    @defer.inlineCallbacks
+    def test_cache(self):
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2):
+                pass
+
+            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+            def list_fn(self, args1, arg2):
+                assert (
+                    logcontext.LoggingContext.current_context().request == "c1"
+                )
+                # we want this to behave like an asynchronous function
+                yield run_on_reactor()
+                assert (
+                    logcontext.LoggingContext.current_context().request == "c1"
+                )
+                defer.returnValue(self.mock(args1, arg2))
+
+        with logcontext.LoggingContext() as c1:
+            c1.request = "c1"
+            obj = Cls()
+            obj.mock.return_value = {10: 'fish', 20: 'chips'}
+            d1 = obj.list_fn([10, 20], 2)
+            self.assertEqual(
+                logcontext.LoggingContext.current_context(),
+                logcontext.LoggingContext.sentinel,
+            )
+            r = yield d1
+            self.assertEqual(
+                logcontext.LoggingContext.current_context(),
+                c1
+            )
+            obj.mock.assert_called_once_with([10, 20], 2)
+            self.assertEqual(r, {10: 'fish', 20: 'chips'})
+            obj.mock.reset_mock()
+
+            # a call with different params should call the mock again
+            obj.mock.return_value = {30: 'peas'}
+            r = yield obj.list_fn([20, 30], 2)
+            obj.mock.assert_called_once_with([30], 2)
+            self.assertEqual(r, {20: 'chips', 30: 'peas'})
+            obj.mock.reset_mock()
+
+            # all the values should now be cached
+            r = yield obj.fn(10, 2)
+            self.assertEqual(r, 'fish')
+            r = yield obj.fn(20, 2)
+            self.assertEqual(r, 'chips')
+            r = yield obj.fn(30, 2)
+            self.assertEqual(r, 'peas')
+            r = yield obj.list_fn([10, 20, 30], 2)
+            obj.mock.assert_not_called()
+            self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
+
+    @defer.inlineCallbacks
+    def test_invalidate(self):
+        """Make sure that invalidation callbacks are called."""
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2):
+                pass
+
+            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+            def list_fn(self, args1, arg2):
+                # we want this to behave like an asynchronous function
+                yield run_on_reactor()
+                defer.returnValue(self.mock(args1, arg2))
+
+        obj = Cls()
+        invalidate0 = mock.Mock()
+        invalidate1 = mock.Mock()
+
+        # cache miss
+        obj.mock.return_value = {10: 'fish', 20: 'chips'}
+        r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
+        obj.mock.assert_called_once_with([10, 20], 2)
+        self.assertEqual(r1, {10: 'fish', 20: 'chips'})
+        obj.mock.reset_mock()
+
+        # cache hit
+        r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
+        obj.mock.assert_not_called()
+        self.assertEqual(r2, {10: 'fish', 20: 'chips'})
+
+        invalidate0.assert_not_called()
+        invalidate1.assert_not_called()
+
+        # now if we invalidate the keys, both invalidations should get called
+        obj.fn.invalidate((10, 2))
+        invalidate0.assert_called_once()
+        invalidate1.assert_called_once()