summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12189.misc1
-rw-r--r--synapse/storage/databases/main/events_worker.py4
-rw-r--r--synapse/util/caches/descriptors.py74
-rw-r--r--tests/util/caches/test_descriptors.py84
4 files changed, 142 insertions, 21 deletions
diff --git a/changelog.d/12189.misc b/changelog.d/12189.misc
new file mode 100644
index 0000000000..015e808e63
--- /dev/null
+++ b/changelog.d/12189.misc
@@ -0,0 +1 @@
+Support skipping some arguments when generating cache keys.
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 26784f755e..59454a47df 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1286,7 +1286,7 @@ class EventsWorkerStore(SQLBaseStore):
         )
         return {eid for ((_rid, eid), have_event) in res.items() if have_event}
 
-    @cachedList("have_seen_event", "keys")
+    @cachedList(cached_method_name="have_seen_event", list_name="keys")
     async def _have_seen_events_dict(
         self, keys: Iterable[Tuple[str, str]]
     ) -> Dict[Tuple[str, str], bool]:
@@ -1954,7 +1954,7 @@ class EventsWorkerStore(SQLBaseStore):
             get_event_id_for_timestamp_txn,
         )
 
-    @cachedList("is_partial_state_event", list_name="event_ids")
+    @cachedList(cached_method_name="is_partial_state_event", list_name="event_ids")
     async def get_partial_state_events(
         self, event_ids: Collection[str]
     ) -> Dict[str, bool]:
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 1cdead02f1..c3c5c16db9 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -20,6 +20,7 @@ from typing import (
     Any,
     Awaitable,
     Callable,
+    Collection,
     Dict,
     Generic,
     Hashable,
@@ -69,6 +70,7 @@ class _CacheDescriptorBase:
         self,
         orig: Callable[..., Any],
         num_args: Optional[int],
+        uncached_args: Optional[Collection[str]] = None,
         cache_context: bool = False,
     ):
         self.orig = orig
@@ -76,6 +78,13 @@ class _CacheDescriptorBase:
         arg_spec = inspect.getfullargspec(orig)
         all_args = arg_spec.args
 
+        # There's no reason that keyword-only arguments couldn't be supported,
+        # but right now they're buggy so do not allow them.
+        if arg_spec.kwonlyargs:
+            raise ValueError(
+                "_CacheDescriptorBase does not support keyword-only arguments."
+            )
+
         if "cache_context" in all_args:
             if not cache_context:
                 raise ValueError(
@@ -88,6 +97,9 @@ class _CacheDescriptorBase:
                 " named `cache_context`"
             )
 
+        if num_args is not None and uncached_args is not None:
+            raise ValueError("Cannot provide both num_args and uncached_args")
+
         if num_args is None:
             num_args = len(all_args) - 1
             if cache_context:
@@ -105,6 +117,12 @@ class _CacheDescriptorBase:
         # list of the names of the args used as the cache key
         self.arg_names = all_args[1 : num_args + 1]
 
+        # If there are args to not cache on, filter them out (and fix the size of num_args).
+        if uncached_args is not None:
+            include_arg_in_cache_key = [n not in uncached_args for n in self.arg_names]
+        else:
+            include_arg_in_cache_key = [True] * len(self.arg_names)
+
         # self.arg_defaults is a map of arg name to its default value for each
         # argument that has a default value
         if arg_spec.defaults:
@@ -119,8 +137,8 @@ class _CacheDescriptorBase:
 
         self.add_cache_context = cache_context
 
-        self.cache_key_builder = get_cache_key_builder(
-            self.arg_names, self.arg_defaults
+        self.cache_key_builder = _get_cache_key_builder(
+            self.arg_names, include_arg_in_cache_key, self.arg_defaults
         )
 
 
@@ -130,8 +148,7 @@ class _LruCachedFunction(Generic[F]):
 
 
 def lru_cache(
-    max_entries: int = 1000,
-    cache_context: bool = False,
+    *, max_entries: int = 1000, cache_context: bool = False
 ) -> Callable[[F], _LruCachedFunction[F]]:
     """A method decorator that applies a memoizing cache around the function.
 
@@ -186,7 +203,9 @@ class LruCacheDescriptor(_CacheDescriptorBase):
         max_entries: int = 1000,
         cache_context: bool = False,
     ):
-        super().__init__(orig, num_args=None, cache_context=cache_context)
+        super().__init__(
+            orig, num_args=None, uncached_args=None, cache_context=cache_context
+        )
         self.max_entries = max_entries
 
     def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
@@ -260,6 +279,9 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         num_args: number of positional arguments (excluding ``self`` and
             ``cache_context``) to use as cache keys. Defaults to all named
             args of the function.
+        uncached_args: a list of argument names to not use as the cache key.
+            (``self`` and ``cache_context`` are always ignored.) Cannot be used
+            with num_args.
         tree:
         cache_context:
         iterable:
@@ -273,12 +295,18 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         orig: Callable[..., Any],
         max_entries: int = 1000,
         num_args: Optional[int] = None,
+        uncached_args: Optional[Collection[str]] = None,
         tree: bool = False,
         cache_context: bool = False,
         iterable: bool = False,
         prune_unread_entries: bool = True,
     ):
-        super().__init__(orig, num_args=num_args, cache_context=cache_context)
+        super().__init__(
+            orig,
+            num_args=num_args,
+            uncached_args=uncached_args,
+            cache_context=cache_context,
+        )
 
         if tree and self.num_args < 2:
             raise RuntimeError(
@@ -369,7 +397,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 but including list_name) to use as cache keys. Defaults to all
                 named args of the function.
         """
-        super().__init__(orig, num_args=num_args)
+        super().__init__(orig, num_args=num_args, uncached_args=None)
 
         self.list_name = list_name
 
@@ -530,8 +558,10 @@ class _CacheContext:
 
 
 def cached(
+    *,
     max_entries: int = 1000,
     num_args: Optional[int] = None,
+    uncached_args: Optional[Collection[str]] = None,
     tree: bool = False,
     cache_context: bool = False,
     iterable: bool = False,
@@ -541,6 +571,7 @@ def cached(
         orig,
         max_entries=max_entries,
         num_args=num_args,
+        uncached_args=uncached_args,
         tree=tree,
         cache_context=cache_context,
         iterable=iterable,
@@ -551,7 +582,7 @@ def cached(
 
 
 def cachedList(
-    cached_method_name: str, list_name: str, num_args: Optional[int] = None
+    *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
 ) -> Callable[[F], _CachedFunction[F]]:
     """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
 
@@ -590,13 +621,16 @@ def cachedList(
     return cast(Callable[[F], _CachedFunction[F]], func)
 
 
-def get_cache_key_builder(
-    param_names: Sequence[str], param_defaults: Mapping[str, Any]
+def _get_cache_key_builder(
+    param_names: Sequence[str],
+    include_params: Sequence[bool],
+    param_defaults: Mapping[str, Any],
 ) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
     """Construct a function which will build cache keys suitable for a cached function
 
     Args:
         param_names: list of formal parameter names for the cached function
+        include_params: list of bools of whether to include the parameter name in the cache key
         param_defaults: a mapping from parameter name to default value for that param
 
     Returns:
@@ -608,6 +642,7 @@ def get_cache_key_builder(
 
     if len(param_names) == 1:
         nm = param_names[0]
+        assert include_params[0] is True
 
         def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
             if nm in kwargs:
@@ -620,13 +655,18 @@ def get_cache_key_builder(
     else:
 
         def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
-            return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
+            return tuple(
+                _get_cache_key_gen(
+                    param_names, include_params, param_defaults, args, kwargs
+                )
+            )
 
     return get_cache_key
 
 
 def _get_cache_key_gen(
     param_names: Iterable[str],
+    include_params: Iterable[bool],
     param_defaults: Mapping[str, Any],
     args: Sequence[Any],
     kwargs: Mapping[str, Any],
@@ -637,16 +677,18 @@ def _get_cache_key_gen(
     This is essentially the same operation as `inspect.getcallargs`, but optimised so
     that we don't need to inspect the target function for each call.
     """
-
     # We loop through each arg name, looking up if its in the `kwargs`,
     # otherwise using the next argument in `args`. If there are no more
     # args then we try looking the arg name up in the defaults.
     pos = 0
-    for nm in param_names:
+    for nm, inc in zip(param_names, include_params):
         if nm in kwargs:
-            yield kwargs[nm]
+            if inc:
+                yield kwargs[nm]
         elif pos < len(args):
-            yield args[pos]
+            if inc:
+                yield args[pos]
             pos += 1
         else:
-            yield param_defaults[nm]
+            if inc:
+                yield param_defaults[nm]
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 19741ffcda..6a4b17527a 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -141,6 +141,84 @@ class DescriptorTestCase(unittest.TestCase):
         self.assertEqual(r, "chips")
         obj.mock.assert_not_called()
 
+    @defer.inlineCallbacks
+    def test_cache_uncached_args(self):
+        """
+        Only the arguments not named in uncached_args should matter to the cache
+
+        Note that this is identical to test_cache_num_args, but provides the
+        arguments differently.
+        """
+
+        class Cls:
+            # Note that it is important that this is not the last argument to
+            # test behaviour of skipping arguments properly.
+            @descriptors.cached(uncached_args=("arg2",))
+            def fn(self, arg1, arg2, arg3):
+                return self.mock(arg1, arg2, arg3)
+
+            def __init__(self):
+                self.mock = mock.Mock()
+
+        obj = Cls()
+        obj.mock.return_value = "fish"
+        r = yield obj.fn(1, 2, 3)
+        self.assertEqual(r, "fish")
+        obj.mock.assert_called_once_with(1, 2, 3)
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = "chips"
+        r = yield obj.fn(2, 3, 4)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_called_once_with(2, 3, 4)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached; we should be able to vary
+        # the second argument and still get the cached result.
+        r = yield obj.fn(1, 4, 3)
+        self.assertEqual(r, "fish")
+        r = yield obj.fn(2, 5, 4)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_not_called()
+
+    @defer.inlineCallbacks
+    def test_cache_kwargs(self):
+        """Test that keyword arguments are treated properly"""
+
+        class Cls:
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, kwarg1=2):
+                return self.mock(arg1, kwarg1=kwarg1)
+
+        obj = Cls()
+        obj.mock.return_value = "fish"
+        r = yield obj.fn(1, kwarg1=2)
+        self.assertEqual(r, "fish")
+        obj.mock.assert_called_once_with(1, kwarg1=2)
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = "chips"
+        r = yield obj.fn(1, kwarg1=3)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_called_once_with(1, kwarg1=3)
+        obj.mock.reset_mock()
+
+        # the values should now be cached.
+        r = yield obj.fn(1, kwarg1=2)
+        self.assertEqual(r, "fish")
+        # We should be able to not provide kwarg1 and get the cached value back.
+        r = yield obj.fn(1)
+        self.assertEqual(r, "fish")
+        # Keyword arguments can be in any order.
+        r = yield obj.fn(kwarg1=2, arg1=1)
+        self.assertEqual(r, "fish")
+        obj.mock.assert_not_called()
+
     def test_cache_with_sync_exception(self):
         """If the wrapped function throws synchronously, things should continue to work"""
 
@@ -656,7 +734,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             def fn(self, arg1, arg2):
                 pass
 
-            @descriptors.cachedList("fn", "args1")
+            @descriptors.cachedList(cached_method_name="fn", list_name="args1")
             async def list_fn(self, args1, arg2):
                 assert current_context().name == "c1"
                 # we want this to behave like an asynchronous function
@@ -715,7 +793,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             def fn(self, arg1):
                 pass
 
-            @descriptors.cachedList("fn", "args1")
+            @descriptors.cachedList(cached_method_name="fn", list_name="args1")
             def list_fn(self, args1) -> "Deferred[dict]":
                 return self.mock(args1)
 
@@ -758,7 +836,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             def fn(self, arg1, arg2):
                 pass
 
-            @descriptors.cachedList("fn", "args1")
+            @descriptors.cachedList(cached_method_name="fn", list_name="args1")
             async def list_fn(self, args1, arg2):
                 # we want this to behave like an asynchronous function
                 await run_on_reactor()