summary refs log tree commit diff
path: root/tests/util/caches/test_descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util/caches/test_descriptors.py')
-rw-r--r--tests/util/caches/test_descriptors.py60
1 files changed, 59 insertions, 1 deletions
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2ad08f541b..cf1e3203a4 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -29,13 +29,46 @@ from synapse.logging.context import (
     make_deferred_yieldable,
 )
 from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, lru_cache
 
 from tests import unittest
+from tests.test_utils import get_awaitable_result
 
 logger = logging.getLogger(__name__)
 
 
+class LruCacheDecoratorTestCase(unittest.TestCase):
+    def test_base(self):
+        class Cls:
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @lru_cache()
+            def fn(self, arg1, arg2):
+                return self.mock(arg1, arg2)
+
+        obj = Cls()
+        obj.mock.return_value = "fish"
+        r = obj.fn(1, 2)
+        self.assertEqual(r, "fish")
+        obj.mock.assert_called_once_with(1, 2)
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = "chips"
+        r = obj.fn(1, 3)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_called_once_with(1, 3)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached
+        r = obj.fn(1, 2)
+        self.assertEqual(r, "fish")
+        r = obj.fn(1, 3)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_not_called()
+
+
 def run_on_reactor():
     d = defer.Deferred()
     reactor.callLater(0, d.callback, 0)
@@ -362,6 +395,31 @@ class DescriptorTestCase(unittest.TestCase):
         d = obj.fn(1)
         self.failureResultOf(d, SynapseError)
 
+    def test_invalidate_cascade(self):
+        """Invalidations should cascade up through cache contexts"""
+
+        class Cls:
+            @cached(cache_context=True)
+            async def func1(self, key, cache_context):
+                return await self.func2(key, on_invalidate=cache_context.invalidate)
+
+            @cached(cache_context=True)
+            async def func2(self, key, cache_context):
+                return self.func3(key, on_invalidate=cache_context.invalidate)
+
+            @lru_cache(cache_context=True)
+            def func3(self, key, cache_context):
+                self.invalidate = cache_context.invalidate
+                return 42
+
+        obj = Cls()
+
+        top_invalidate = mock.Mock()
+        r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate))
+        self.assertEqual(r, 42)
+        obj.invalidate()
+        top_invalidate.assert_called_once()
+
 
 class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     """More tests for @cached