summary refs log tree commit diff
path: root/tests/util/caches/test_descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util/caches/test_descriptors.py')
-rw-r--r--tests/util/caches/test_descriptors.py151
1 files changed, 130 insertions, 21 deletions
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2516fe40f4..463a737efa 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -13,20 +13,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from functools import partial
 import logging
+from functools import partial
 
 import mock
+
+from twisted.internet import defer, reactor
+
 from synapse.api.errors import SynapseError
-from synapse.util import async
 from synapse.util import logcontext
-from twisted.internet import defer
 from synapse.util.caches import descriptors
+
 from tests import unittest
 
 logger = logging.getLogger(__name__)
 
 
+def run_on_reactor():
+    d = defer.Deferred()
+    reactor.callLater(0, d.callback, 0)
+    return logcontext.make_deferred_yieldable(d)
+
+
 class CacheTestCase(unittest.TestCase):
     def test_invalidate_all(self):
         cache = descriptors.Cache("testcache")
@@ -59,12 +67,8 @@ class CacheTestCase(unittest.TestCase):
         self.assertIsNone(cache.get("key2", None))
 
         # both callbacks should have been callbacked
-        self.assertTrue(
-            callback_record[0], "Invalidation callback for key1 not called",
-        )
-        self.assertTrue(
-            callback_record[1], "Invalidation callback for key2 not called",
-        )
+        self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
+        self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
 
         # letting the other lookup complete should do nothing
         d1.callback("result1")
@@ -160,8 +164,7 @@ class DescriptorTestCase(unittest.TestCase):
             with logcontext.LoggingContext() as c1:
                 c1.name = "c1"
                 r = yield obj.fn(1)
-                self.assertEqual(logcontext.LoggingContext.current_context(),
-                                 c1)
+                self.assertEqual(logcontext.LoggingContext.current_context(), c1)
             defer.returnValue(r)
 
         def check_result(r):
@@ -171,14 +174,18 @@ class DescriptorTestCase(unittest.TestCase):
 
         # set off a deferred which will do a cache lookup
         d1 = do_lookup()
-        self.assertEqual(logcontext.LoggingContext.current_context(),
-                         logcontext.LoggingContext.sentinel)
+        self.assertEqual(
+            logcontext.LoggingContext.current_context(),
+            logcontext.LoggingContext.sentinel,
+        )
         d1.addCallback(check_result)
 
         # and another
         d2 = do_lookup()
-        self.assertEqual(logcontext.LoggingContext.current_context(),
-                         logcontext.LoggingContext.sentinel)
+        self.assertEqual(
+            logcontext.LoggingContext.current_context(),
+            logcontext.LoggingContext.sentinel,
+        )
         d2.addCallback(check_result)
 
         # let the lookup complete
@@ -195,7 +202,8 @@ class DescriptorTestCase(unittest.TestCase):
             def fn(self, arg1):
                 @defer.inlineCallbacks
                 def inner_fn():
-                    yield async.run_on_reactor()
+                    # we want this to behave like an asynchronous function
+                    yield run_on_reactor()
                     raise SynapseError(400, "blah")
 
                 return inner_fn()
@@ -205,20 +213,26 @@ class DescriptorTestCase(unittest.TestCase):
             with logcontext.LoggingContext() as c1:
                 c1.name = "c1"
                 try:
-                    yield obj.fn(1)
+                    d = obj.fn(1)
+                    self.assertEqual(
+                        logcontext.LoggingContext.current_context(),
+                        logcontext.LoggingContext.sentinel,
+                    )
+                    yield d
                     self.fail("No exception thrown")
                 except SynapseError:
                     pass
 
-                self.assertEqual(logcontext.LoggingContext.current_context(),
-                                 c1)
+                self.assertEqual(logcontext.LoggingContext.current_context(), c1)
 
         obj = Cls()
 
         # set off a deferred which will do a cache lookup
         d1 = do_lookup()
-        self.assertEqual(logcontext.LoggingContext.current_context(),
-                         logcontext.LoggingContext.sentinel)
+        self.assertEqual(
+            logcontext.LoggingContext.current_context(),
+            logcontext.LoggingContext.sentinel,
+        )
 
         return d1
 
@@ -259,3 +273,98 @@ class DescriptorTestCase(unittest.TestCase):
         r = yield obj.fn(2, 3)
         self.assertEqual(r, 'chips')
         obj.mock.assert_not_called()
+
+
+class CachedListDescriptorTestCase(unittest.TestCase):
+    @defer.inlineCallbacks
+    def test_cache(self):
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2):
+                pass
+
+            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+            def list_fn(self, args1, arg2):
+                assert logcontext.LoggingContext.current_context().request == "c1"
+                # we want this to behave like an asynchronous function
+                yield run_on_reactor()
+                assert logcontext.LoggingContext.current_context().request == "c1"
+                defer.returnValue(self.mock(args1, arg2))
+
+        with logcontext.LoggingContext() as c1:
+            c1.request = "c1"
+            obj = Cls()
+            obj.mock.return_value = {10: 'fish', 20: 'chips'}
+            d1 = obj.list_fn([10, 20], 2)
+            self.assertEqual(
+                logcontext.LoggingContext.current_context(),
+                logcontext.LoggingContext.sentinel,
+            )
+            r = yield d1
+            self.assertEqual(logcontext.LoggingContext.current_context(), c1)
+            obj.mock.assert_called_once_with([10, 20], 2)
+            self.assertEqual(r, {10: 'fish', 20: 'chips'})
+            obj.mock.reset_mock()
+
+            # a call with different params should call the mock again
+            obj.mock.return_value = {30: 'peas'}
+            r = yield obj.list_fn([20, 30], 2)
+            obj.mock.assert_called_once_with([30], 2)
+            self.assertEqual(r, {20: 'chips', 30: 'peas'})
+            obj.mock.reset_mock()
+
+            # all the values should now be cached
+            r = yield obj.fn(10, 2)
+            self.assertEqual(r, 'fish')
+            r = yield obj.fn(20, 2)
+            self.assertEqual(r, 'chips')
+            r = yield obj.fn(30, 2)
+            self.assertEqual(r, 'peas')
+            r = yield obj.list_fn([10, 20, 30], 2)
+            obj.mock.assert_not_called()
+            self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
+
+    @defer.inlineCallbacks
+    def test_invalidate(self):
+        """Make sure that invalidation callbacks are called."""
+
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2):
+                pass
+
+            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+            def list_fn(self, args1, arg2):
+                # we want this to behave like an asynchronous function
+                yield run_on_reactor()
+                defer.returnValue(self.mock(args1, arg2))
+
+        obj = Cls()
+        invalidate0 = mock.Mock()
+        invalidate1 = mock.Mock()
+
+        # cache miss
+        obj.mock.return_value = {10: 'fish', 20: 'chips'}
+        r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
+        obj.mock.assert_called_once_with([10, 20], 2)
+        self.assertEqual(r1, {10: 'fish', 20: 'chips'})
+        obj.mock.reset_mock()
+
+        # cache hit
+        r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
+        obj.mock.assert_not_called()
+        self.assertEqual(r2, {10: 'fish', 20: 'chips'})
+
+        invalidate0.assert_not_called()
+        invalidate1.assert_not_called()
+
+        # now if we invalidate the keys, both invalidations should get called
+        obj.fn.invalidate((10, 2))
+        invalidate0.assert_called_once()
+        invalidate1.assert_called_once()