summary refs log tree commit diff
path: root/tests/util/caches/test_descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/util/caches/test_descriptors.py')
-rw-r--r--tests/util/caches/test_descriptors.py231
1 files changed, 226 insertions, 5 deletions
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 19741ffcda..48e616ac74 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -17,7 +17,7 @@ from typing import Set
 from unittest import mock
 
 from twisted.internet import defer, reactor
-from twisted.internet.defer import Deferred
+from twisted.internet.defer import CancelledError, Deferred
 
 from synapse.api.errors import SynapseError
 from synapse.logging.context import (
@@ -28,7 +28,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
 )
 from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached, lru_cache
+from synapse.util.caches.descriptors import cached, cachedList, lru_cache
 
 from tests import unittest
 from tests.test_utils import get_awaitable_result
@@ -141,6 +141,84 @@ class DescriptorTestCase(unittest.TestCase):
         self.assertEqual(r, "chips")
         obj.mock.assert_not_called()
 
+    @defer.inlineCallbacks
+    def test_cache_uncached_args(self):
+        """
+        Only the arguments not named in uncached_args should matter to the cache
+
+        Note that this is identical to test_cache_num_args, but provides the
+        arguments differently.
+        """
+
+        class Cls:
+            # Note that it is important that this is not the last argument to
+            # test behaviour of skipping arguments properly.
+            @descriptors.cached(uncached_args=("arg2",))
+            def fn(self, arg1, arg2, arg3):
+                return self.mock(arg1, arg2, arg3)
+
+            def __init__(self):
+                self.mock = mock.Mock()
+
+        obj = Cls()
+        obj.mock.return_value = "fish"
+        r = yield obj.fn(1, 2, 3)
+        self.assertEqual(r, "fish")
+        obj.mock.assert_called_once_with(1, 2, 3)
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = "chips"
+        r = yield obj.fn(2, 3, 4)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_called_once_with(2, 3, 4)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached; we should be able to vary
+        # the second argument and still get the cached result.
+        r = yield obj.fn(1, 4, 3)
+        self.assertEqual(r, "fish")
+        r = yield obj.fn(2, 5, 4)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_not_called()
+
+    @defer.inlineCallbacks
+    def test_cache_kwargs(self):
+        """Test that keyword arguments are treated properly"""
+
+        class Cls:
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, kwarg1=2):
+                return self.mock(arg1, kwarg1=kwarg1)
+
+        obj = Cls()
+        obj.mock.return_value = "fish"
+        r = yield obj.fn(1, kwarg1=2)
+        self.assertEqual(r, "fish")
+        obj.mock.assert_called_once_with(1, kwarg1=2)
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = "chips"
+        r = yield obj.fn(1, kwarg1=3)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_called_once_with(1, kwarg1=3)
+        obj.mock.reset_mock()
+
+        # the values should now be cached.
+        r = yield obj.fn(1, kwarg1=2)
+        self.assertEqual(r, "fish")
+        # We should be able to not provide kwarg1 and get the cached value back.
+        r = yield obj.fn(1)
+        self.assertEqual(r, "fish")
+        # Keyword arguments can be in any order.
+        r = yield obj.fn(kwarg1=2, arg1=1)
+        self.assertEqual(r, "fish")
+        obj.mock.assert_not_called()
+
     def test_cache_with_sync_exception(self):
         """If the wrapped function throws synchronously, things should continue to work"""
 
@@ -415,6 +493,74 @@ class DescriptorTestCase(unittest.TestCase):
         obj.invalidate()
         top_invalidate.assert_called_once()
 
+    def test_cancel(self):
+        """Test that cancelling a lookup does not cancel other lookups"""
+        complete_lookup: "Deferred[None]" = Deferred()
+
+        class Cls:
+            @cached()
+            async def fn(self, arg1):
+                await complete_lookup
+                return str(arg1)
+
+        obj = Cls()
+
+        d1 = obj.fn(123)
+        d2 = obj.fn(123)
+        self.assertFalse(d1.called)
+        self.assertFalse(d2.called)
+
+        # Cancel `d1`, which is the lookup that caused `fn` to run.
+        d1.cancel()
+
+        # `d2` should complete normally.
+        complete_lookup.callback(None)
+        self.failureResultOf(d1, CancelledError)
+        self.assertEqual(d2.result, "123")
+
+    def test_cancel_logcontexts(self):
+        """Test that cancellation does not break logcontexts.
+
+        * The `CancelledError` must be raised with the correct logcontext.
+        * The inner lookup must not resume with a finished logcontext.
+        * The inner lookup must not restore a finished logcontext when done.
+        """
+        complete_lookup: "Deferred[None]" = Deferred()
+
+        class Cls:
+            inner_context_was_finished = False
+
+            @cached()
+            async def fn(self, arg1):
+                await make_deferred_yieldable(complete_lookup)
+                self.inner_context_was_finished = current_context().finished
+                return str(arg1)
+
+        obj = Cls()
+
+        async def do_lookup():
+            with LoggingContext("c1") as c1:
+                try:
+                    await obj.fn(123)
+                    self.fail("No CancelledError thrown")
+                except CancelledError:
+                    self.assertEqual(
+                        current_context(),
+                        c1,
+                        "CancelledError was not raised with the correct logcontext",
+                    )
+                    # suppress the error and succeed
+
+        d = defer.ensureDeferred(do_lookup())
+        d.cancel()
+
+        complete_lookup.callback(None)
+        self.successResultOf(d)
+        self.assertFalse(
+            obj.inner_context_was_finished, "Tried to restart a finished logcontext"
+        )
+        self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
 
 class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     """More tests for @cached
@@ -656,7 +802,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             def fn(self, arg1, arg2):
                 pass
 
-            @descriptors.cachedList("fn", "args1")
+            @descriptors.cachedList(cached_method_name="fn", list_name="args1")
             async def list_fn(self, args1, arg2):
                 assert current_context().name == "c1"
                 # we want this to behave like an asynchronous function
@@ -715,7 +861,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             def fn(self, arg1):
                 pass
 
-            @descriptors.cachedList("fn", "args1")
+            @descriptors.cachedList(cached_method_name="fn", list_name="args1")
             def list_fn(self, args1) -> "Deferred[dict]":
                 return self.mock(args1)
 
@@ -758,7 +904,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             def fn(self, arg1, arg2):
                 pass
 
-            @descriptors.cachedList("fn", "args1")
+            @descriptors.cachedList(cached_method_name="fn", list_name="args1")
             async def list_fn(self, args1, arg2):
                 # we want this to behave like an asynchronous function
                 await run_on_reactor()
@@ -787,3 +933,78 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         obj.fn.invalidate((10, 2))
         invalidate0.assert_called_once()
         invalidate1.assert_called_once()
+
+    def test_cancel(self):
+        """Test that cancelling a lookup does not cancel other lookups"""
+        complete_lookup: "Deferred[None]" = Deferred()
+
+        class Cls:
+            @cached()
+            def fn(self, arg1):
+                pass
+
+            @cachedList(cached_method_name="fn", list_name="args")
+            async def list_fn(self, args):
+                await complete_lookup
+                return {arg: str(arg) for arg in args}
+
+        obj = Cls()
+
+        d1 = obj.list_fn([123, 456])
+        d2 = obj.list_fn([123, 456, 789])
+        self.assertFalse(d1.called)
+        self.assertFalse(d2.called)
+
+        d1.cancel()
+
+        # `d2` should complete normally.
+        complete_lookup.callback(None)
+        self.failureResultOf(d1, CancelledError)
+        self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"})
+
+    def test_cancel_logcontexts(self):
+        """Test that cancellation does not break logcontexts.
+
+        * The `CancelledError` must be raised with the correct logcontext.
+        * The inner lookup must not resume with a finished logcontext.
+        * The inner lookup must not restore a finished logcontext when done.
+        """
+        complete_lookup: "Deferred[None]" = Deferred()
+
+        class Cls:
+            inner_context_was_finished = False
+
+            @cached()
+            def fn(self, arg1):
+                pass
+
+            @cachedList(cached_method_name="fn", list_name="args")
+            async def list_fn(self, args):
+                await make_deferred_yieldable(complete_lookup)
+                self.inner_context_was_finished = current_context().finished
+                return {arg: str(arg) for arg in args}
+
+        obj = Cls()
+
+        async def do_lookup():
+            with LoggingContext("c1") as c1:
+                try:
+                    await obj.list_fn([123])
+                    self.fail("No CancelledError thrown")
+                except CancelledError:
+                    self.assertEqual(
+                        current_context(),
+                        c1,
+                        "CancelledError was not raised with the correct logcontext",
+                    )
+                    # suppress the error and succeed
+
+        d = defer.ensureDeferred(do_lookup())
+        d.cancel()
+
+        complete_lookup.callback(None)
+        self.successResultOf(d)
+        self.assertFalse(
+            obj.inner_context_was_finished, "Tried to restart a finished logcontext"
+        )
+        self.assertEqual(current_context(), SENTINEL_CONTEXT)