summary refs log tree commit diff
path: root/tests/util/caches/test_descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util/caches/test_descriptors.py')
-rw-r--r--tests/util/caches/test_descriptors.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index fc2663c02d..2ad08f541b 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import Set
 
 import mock
 
@@ -130,6 +131,57 @@ class DescriptorTestCase(unittest.TestCase):
         d = obj.fn(1)
         self.failureResultOf(d, SynapseError)
 
+    def test_cache_with_async_exception(self):
+        """The wrapped function returns a failure
+        """
+
+        class Cls:
+            result = None
+            call_count = 0
+
+            @cached()
+            def fn(self, arg1):
+                self.call_count += 1
+                return self.result
+
+        obj = Cls()
+        callbacks = set()  # type: Set[str]
+
+        # set off an asynchronous request
+        obj.result = origin_d = defer.Deferred()
+
+        d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
+        self.assertFalse(d1.called)
+
+        # a second request should also return a deferred, but should not call the
+        # function itself.
+        d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2"))
+        self.assertFalse(d2.called)
+        self.assertEqual(obj.call_count, 1)
+
+        # no callbacks yet
+        self.assertEqual(callbacks, set())
+
+        # the original request fails
+        e = Exception("bzz")
+        origin_d.errback(e)
+
+        # ... which should cause the lookups to fail similarly
+        self.assertIs(self.failureResultOf(d1, Exception).value, e)
+        self.assertIs(self.failureResultOf(d2, Exception).value, e)
+
+        # ... and the callbacks to have been, uh, called.
+        self.assertEqual(callbacks, {"d1", "d2"})
+
+        # ... leaving the cache empty
+        self.assertEqual(len(obj.fn.cache.cache), 0)
+
+        # and a second call should work as normal
+        obj.result = defer.succeed(100)
+        d3 = obj.fn(1)
+        self.assertEqual(self.successResultOf(d3), 100)
+        self.assertEqual(obj.call_count, 2)
+
     def test_cache_logcontexts(self):
         """Check that logcontexts are set and restored correctly when
         using the cache."""