summary refs log tree commit diff
path: root/tests/util/caches/test_descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util/caches/test_descriptors.py')
-rw-r--r--tests/util/caches/test_descriptors.py98
1 files changed, 91 insertions, 7 deletions
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 7807328e2f..5713870f48 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -27,6 +27,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
 )
 from synapse.util.caches import descriptors
+from synapse.util.caches.descriptors import cached
 
 from tests import unittest
 
@@ -55,12 +56,15 @@ class CacheTestCase(unittest.TestCase):
         d2 = defer.Deferred()
         cache.set("key2", d2, partial(record_callback, 1))
 
-        # lookup should return the deferreds
-        self.assertIs(cache.get("key1"), d1)
-        self.assertIs(cache.get("key2"), d2)
+        # lookup should return observable deferreds
+        self.assertFalse(cache.get("key1").has_called())
+        self.assertFalse(cache.get("key2").has_called())
 
         # let one of the lookups complete
         d2.callback("result2")
+
+        # for now at least, the cache will return real results rather than an
+        # observabledeferred
         self.assertEqual(cache.get("key2"), "result2")
 
         # now do the invalidation
@@ -146,6 +150,28 @@ class DescriptorTestCase(unittest.TestCase):
         self.assertEqual(r, "chips")
         obj.mock.assert_not_called()
 
+    def test_cache_with_sync_exception(self):
+        """If the wrapped function throws synchronously, things should continue to work
+        """
+
+        class Cls(object):
+            @cached()
+            def fn(self, arg1):
+                raise SynapseError(100, "mai spoon iz too big!!1")
+
+        obj = Cls()
+
+        # this should fail immediately
+        d = obj.fn(1)
+        self.failureResultOf(d, SynapseError)
+
+        # ... leaving the cache empty
+        self.assertEqual(len(obj.fn.cache.cache), 0)
+
+        # and a second call should result in a second exception
+        d = obj.fn(1)
+        self.failureResultOf(d, SynapseError)
+
     def test_cache_logcontexts(self):
         """Check that logcontexts are set and restored correctly when
         using the cache."""
@@ -159,7 +185,7 @@ class DescriptorTestCase(unittest.TestCase):
                 def inner_fn():
                     with PreserveLoggingContext():
                         yield complete_lookup
-                    defer.returnValue(1)
+                    return 1
 
                 return inner_fn()
 
@@ -169,7 +195,7 @@ class DescriptorTestCase(unittest.TestCase):
                 c1.name = "c1"
                 r = yield obj.fn(1)
                 self.assertEqual(LoggingContext.current_context(), c1)
-            defer.returnValue(r)
+            return r
 
         def check_result(r):
             self.assertEqual(r, 1)
@@ -222,6 +248,9 @@ class DescriptorTestCase(unittest.TestCase):
 
                 self.assertEqual(LoggingContext.current_context(), c1)
 
+            # the cache should now be empty
+            self.assertEqual(len(obj.fn.cache.cache), 0)
+
         obj = Cls()
 
         # set off a deferred which will do a cache lookup
@@ -268,6 +297,61 @@ class DescriptorTestCase(unittest.TestCase):
         self.assertEqual(r, "chips")
         obj.mock.assert_not_called()
 
+    def test_cache_iterable(self):
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached(iterable=True)
+            def fn(self, arg1, arg2):
+                return self.mock(arg1, arg2)
+
+        obj = Cls()
+
+        obj.mock.return_value = ["spam", "eggs"]
+        r = obj.fn(1, 2)
+        self.assertEqual(r, ["spam", "eggs"])
+        obj.mock.assert_called_once_with(1, 2)
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = ["chips"]
+        r = obj.fn(1, 3)
+        self.assertEqual(r, ["chips"])
+        obj.mock.assert_called_once_with(1, 3)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached
+        self.assertEqual(len(obj.fn.cache.cache), 3)
+
+        r = obj.fn(1, 2)
+        self.assertEqual(r, ["spam", "eggs"])
+        r = obj.fn(1, 3)
+        self.assertEqual(r, ["chips"])
+        obj.mock.assert_not_called()
+
+    def test_cache_iterable_with_sync_exception(self):
+        """If the wrapped function throws synchronously, things should continue to work
+        """
+
+        class Cls(object):
+            @descriptors.cached(iterable=True)
+            def fn(self, arg1):
+                raise SynapseError(100, "mai spoon iz too big!!1")
+
+        obj = Cls()
+
+        # this should fail immediately
+        d = obj.fn(1)
+        self.failureResultOf(d, SynapseError)
+
+        # ... leaving the cache empty
+        self.assertEqual(len(obj.fn.cache.cache), 0)
+
+        # and a second call should result in a second exception
+        d = obj.fn(1)
+        self.failureResultOf(d, SynapseError)
+
 
 class CachedListDescriptorTestCase(unittest.TestCase):
     @defer.inlineCallbacks
@@ -286,7 +370,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
                 # we want this to behave like an asynchronous function
                 yield run_on_reactor()
                 assert LoggingContext.current_context().request == "c1"
-                defer.returnValue(self.mock(args1, arg2))
+                return self.mock(args1, arg2)
 
         with LoggingContext() as c1:
             c1.request = "c1"
@@ -334,7 +418,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             def list_fn(self, args1, arg2):
                 # we want this to behave like an asynchronous function
                 yield run_on_reactor()
-                defer.returnValue(self.mock(args1, arg2))
+                return self.mock(args1, arg2)
 
         obj = Cls()
         invalidate0 = mock.Mock()