summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8595.misc1
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py37
-rw-r--r--synapse/util/caches/descriptors.py235
-rw-r--r--tests/util/caches/test_descriptors.py60
4 files changed, 272 insertions, 61 deletions
diff --git a/changelog.d/8595.misc b/changelog.d/8595.misc
new file mode 100644
index 0000000000..24fab65cda
--- /dev/null
+++ b/changelog.d/8595.misc
@@ -0,0 +1 @@
+Implement and use an @lru_cache decorator.
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index d9b5478b53..82a72dc34f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,8 +15,8 @@
 # limitations under the License.
 
 import logging
-from collections import namedtuple
 
+import attr
 from prometheus_client import Counter
 
 from synapse.api.constants import EventTypes, Membership, RelationTypes
@@ -26,7 +26,8 @@ from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import register_cache
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import lru_cache
+from synapse.util.caches.lrucache import LruCache
 
 from .push_rule_evaluator import PushRuleEvaluatorForEvent
 
@@ -120,7 +121,7 @@ class BulkPushRuleEvaluator:
             dict of user_id -> push_rules
         """
         room_id = event.room_id
-        rules_for_room = await self._get_rules_for_room(room_id)
+        rules_for_room = self._get_rules_for_room(room_id)
 
         rules_by_user = await rules_for_room.get_rules(event, context)
 
@@ -138,7 +139,7 @@ class BulkPushRuleEvaluator:
 
         return rules_by_user
 
-    @cached()
+    @lru_cache()
     def _get_rules_for_room(self, room_id):
         """Get the current RulesForRoom object for the given room id
 
@@ -275,12 +276,14 @@ class RulesForRoom:
     the entire cache for the room.
     """
 
-    def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
+    def __init__(
+        self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
+    ):
         """
         Args:
             hs (HomeServer)
             room_id (str)
-            rules_for_room_cache(Cache): The cache object that caches these
+            rules_for_room_cache: The cache object that caches these
                 RoomsForUser objects.
             room_push_rule_cache_metrics (CacheMetric)
         """
@@ -489,13 +492,21 @@ class RulesForRoom:
             self.state_group = state_group
 
 
-class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
-    # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
-    # which namedtuple does for us (i.e. two _CacheContext are the same if
-    # their caches and keys match). This is important in particular to
-    # dedupe when we add callbacks to lru cache nodes, otherwise the number
-    # of callbacks would grow.
+@attr.attrs(slots=True, frozen=True)
+class _Invalidation:
+    # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
+    # which means that it it is stored on the bulk_get_push_rules cache entry. In order
+    # to ensure that we don't accumulate lots of redunant callbacks on the cache entry,
+    # we need to ensure that two _Invalidation objects are "equal" if they refer to the
+    # same `cache` and `room_id`.
+    #
+    # attrs provides suitable __hash__ and __eq__ methods, provided we remember to
+    # set `frozen=True`.
+
+    cache = attr.ib(type=LruCache)
+    room_id = attr.ib(type=str)
+
     def __call__(self):
-        rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
+        rules = self.cache.get(self.room_id, None, update_metrics=False)
         if rules:
             rules.invalidate_all()
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5d7fffee66..a924140cdf 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,10 +13,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import functools
 import inspect
 import logging
-from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 from weakref import WeakValueDictionary
 
 from twisted.internet import defer
@@ -24,6 +37,7 @@ from twisted.internet import defer
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
 from synapse.util.caches.deferred_cache import DeferredCache
+from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
 
@@ -48,7 +62,7 @@ class _CachedFunction(Generic[F]):
 
 
 class _CacheDescriptorBase:
-    def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
+    def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
         self.orig = orig
 
         arg_spec = inspect.getfullargspec(orig)
@@ -97,8 +111,107 @@ class _CacheDescriptorBase:
 
         self.add_cache_context = cache_context
 
+        self.cache_key_builder = get_cache_key_builder(
+            self.arg_names, self.arg_defaults
+        )
+
+
+class _LruCachedFunction(Generic[F]):
+    cache = None  # type: LruCache[CacheKey, Any]
+    __call__ = None  # type: F
+
+
+def lru_cache(
+    max_entries: int = 1000, cache_context: bool = False,
+) -> Callable[[F], _LruCachedFunction[F]]:
+    """A method decorator that applies a memoizing cache around the function.
+
+    This is more-or-less a drop-in equivalent to functools.lru_cache, although note
+    that the signature is slightly different.
+
+    The main differences with functools.lru_cache are:
+        (a) the size of the cache can be controlled via the cache_factor mechanism
+        (b) the wrapped function can request a "cache_context" which provides a
+            callback mechanism to indicate that the result is no longer valid
+        (c) prometheus metrics are exposed automatically.
+
+    The function should take zero or more arguments, which are used as the key for the
+    cache. Single-argument functions use that argument as the cache key; otherwise the
+    arguments are built into a tuple.
+
+    Cached functions can be "chained" (i.e. a cached function can call other cached
+    functions and get appropriately invalidated when they called caches are
+    invalidated) by adding a special "cache_context" argument to the function
+    and passing that as a kwarg to all caches called. For example:
+
+        @lru_cache(cache_context=True)
+        def foo(self, key, cache_context):
+            r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
+            r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
+            return r1 + r2
+
+    The wrapped function also has a 'cache' property which offers direct access to the
+    underlying LruCache.
+    """
+
+    def func(orig: F) -> _LruCachedFunction[F]:
+        desc = LruCacheDescriptor(
+            orig, max_entries=max_entries, cache_context=cache_context,
+        )
+        return cast(_LruCachedFunction[F], desc)
+
+    return func
+
+
+class LruCacheDescriptor(_CacheDescriptorBase):
+    """Helper for @lru_cache"""
+
+    class _Sentinel(enum.Enum):
+        sentinel = object()
+
+    def __init__(
+        self, orig, max_entries: int = 1000, cache_context: bool = False,
+    ):
+        super().__init__(orig, num_args=None, cache_context=cache_context)
+        self.max_entries = max_entries
+
+    def __get__(self, obj, owner):
+        cache = LruCache(
+            cache_name=self.orig.__name__, max_size=self.max_entries,
+        )  # type: LruCache[CacheKey, Any]
+
+        get_cache_key = self.cache_key_builder
+        sentinel = LruCacheDescriptor._Sentinel.sentinel
+
+        @functools.wraps(self.orig)
+        def _wrapped(*args, **kwargs):
+            invalidate_callback = kwargs.pop("on_invalidate", None)
+            callbacks = (invalidate_callback,) if invalidate_callback else ()
+
+            cache_key = get_cache_key(args, kwargs)
 
-class CacheDescriptor(_CacheDescriptorBase):
+            ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
+            if ret != sentinel:
+                return ret
+
+            # Add our own `cache_context` to argument list if the wrapped function
+            # has asked for one
+            if self.add_cache_context:
+                kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
+
+            ret2 = self.orig(obj, *args, **kwargs)
+            cache.set(cache_key, ret2, callbacks=callbacks)
+
+            return ret2
+
+        wrapped = cast(_CachedFunction, _wrapped)
+        wrapped.cache = cache
+        obj.__dict__[self.orig.__name__] = wrapped
+
+        return wrapped
+
+
+class DeferredCacheDescriptor(_CacheDescriptorBase):
     """ A method decorator that applies a memoizing cache around the function.
 
     This caches deferreds, rather than the results themselves. Deferreds that
@@ -141,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase):
         cache_context=False,
         iterable=False,
     ):
-
         super().__init__(orig, num_args=num_args, cache_context=cache_context)
 
         self.max_entries = max_entries
@@ -157,41 +269,7 @@ class CacheDescriptor(_CacheDescriptorBase):
             iterable=self.iterable,
         )  # type: DeferredCache[CacheKey, Any]
 
-        def get_cache_key_gen(args, kwargs):
-            """Given some args/kwargs return a generator that resolves into
-            the cache_key.
-
-            We loop through each arg name, looking up if its in the `kwargs`,
-            otherwise using the next argument in `args`. If there are no more
-            args then we try looking the arg name up in the defaults
-            """
-            pos = 0
-            for nm in self.arg_names:
-                if nm in kwargs:
-                    yield kwargs[nm]
-                elif pos < len(args):
-                    yield args[pos]
-                    pos += 1
-                else:
-                    yield self.arg_defaults[nm]
-
-        # By default our cache key is a tuple, but if there is only one item
-        # then don't bother wrapping in a tuple.  This is to save memory.
-        if self.num_args == 1:
-            nm = self.arg_names[0]
-
-            def get_cache_key(args, kwargs):
-                if nm in kwargs:
-                    return kwargs[nm]
-                elif len(args):
-                    return args[0]
-                else:
-                    return self.arg_defaults[nm]
-
-        else:
-
-            def get_cache_key(args, kwargs):
-                return tuple(get_cache_key_gen(args, kwargs))
+        get_cache_key = self.cache_key_builder
 
         @functools.wraps(self.orig)
         def _wrapped(*args, **kwargs):
@@ -223,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase):
             wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
         else:
             wrapped.invalidate = cache.invalidate
-            wrapped.invalidate_all = cache.invalidate_all
             wrapped.invalidate_many = cache.invalidate_many
             wrapped.prefill = cache.prefill
 
@@ -236,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase):
         return wrapped
 
 
-class CacheListDescriptor(_CacheDescriptorBase):
+class DeferredCacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
     Given a list of keys it looks in the cache to find any hits, then passes
@@ -382,11 +459,13 @@ class _CacheContext:
     on a lower level.
     """
 
+    Cache = Union[DeferredCache, LruCache]
+
     _cache_context_objects = (
         WeakValueDictionary()
-    )  # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
+    )  # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
 
-    def __init__(self, cache, cache_key):  # type: (DeferredCache, CacheKey) -> None
+    def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
         self._cache = cache
         self._cache_key = cache_key
 
@@ -396,8 +475,8 @@ class _CacheContext:
 
     @classmethod
     def get_instance(
-        cls, cache, cache_key
-    ):  # type: (DeferredCache, CacheKey) -> _CacheContext
+        cls, cache: "_CacheContext.Cache", cache_key: CacheKey
+    ) -> "_CacheContext":
         """Returns an instance constructed with the given arguments.
 
         A new instance is only created if none already exists.
@@ -418,7 +497,7 @@ def cached(
     cache_context: bool = False,
     iterable: bool = False,
 ) -> Callable[[F], _CachedFunction[F]]:
-    func = lambda orig: CacheDescriptor(
+    func = lambda orig: DeferredCacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
@@ -460,7 +539,7 @@ def cachedList(
             def batch_do_something(self, first_arg, second_args):
                 ...
     """
-    func = lambda orig: CacheListDescriptor(
+    func = lambda orig: DeferredCacheListDescriptor(
         orig,
         cached_method_name=cached_method_name,
         list_name=list_name,
@@ -468,3 +547,65 @@ def cachedList(
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)
+
+
+def get_cache_key_builder(
+    param_names: Sequence[str], param_defaults: Mapping[str, Any]
+) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
+    """Construct a function which will build cache keys suitable for a cached function
+
+    Args:
+        param_names: list of formal parameter names for the cached function
+        param_defaults: a mapping from parameter name to default value for that param
+
+    Returns:
+        A function which will take an (args, kwargs) pair and return a cache key
+    """
+
+    # By default our cache key is a tuple, but if there is only one item
+    # then don't bother wrapping in a tuple.  This is to save memory.
+
+    if len(param_names) == 1:
+        nm = param_names[0]
+
+        def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+            if nm in kwargs:
+                return kwargs[nm]
+            elif len(args):
+                return args[0]
+            else:
+                return param_defaults[nm]
+
+    else:
+
+        def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+            return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
+
+    return get_cache_key
+
+
+def _get_cache_key_gen(
+    param_names: Iterable[str],
+    param_defaults: Mapping[str, Any],
+    args: Sequence[Any],
+    kwargs: Mapping[str, Any],
+) -> Iterable[Any]:
+    """Given some args/kwargs return a generator that resolves into
+    the cache_key.
+
+    This is essentially the same operation as `inspect.getcallargs`, but optimised so
+    that we don't need to inspect the target function for each call.
+    """
+
+    # We loop through each arg name, looking up if its in the `kwargs`,
+    # otherwise using the next argument in `args`. If there are no more
+    # args then we try looking the arg name up in the defaults.
+    pos = 0
+    for nm in param_names:
+        if nm in kwargs:
+            yield kwargs[nm]
+        elif pos < len(args):
+            yield args[pos]
+            pos += 1
+        else:
+            yield param_defaults[nm]
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2ad08f541b..cf1e3203a4 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -29,13 +29,46 @@ from synapse.logging.context import (
     make_deferred_yieldable,
 )
 from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, lru_cache
 
 from tests import unittest
+from tests.test_utils import get_awaitable_result
 
 logger = logging.getLogger(__name__)
 
 
+class LruCacheDecoratorTestCase(unittest.TestCase):
+    def test_base(self):
+        class Cls:
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @lru_cache()
+            def fn(self, arg1, arg2):
+                return self.mock(arg1, arg2)
+
+        obj = Cls()
+        obj.mock.return_value = "fish"
+        r = obj.fn(1, 2)
+        self.assertEqual(r, "fish")
+        obj.mock.assert_called_once_with(1, 2)
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = "chips"
+        r = obj.fn(1, 3)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_called_once_with(1, 3)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached
+        r = obj.fn(1, 2)
+        self.assertEqual(r, "fish")
+        r = obj.fn(1, 3)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_not_called()
+
+
 def run_on_reactor():
     d = defer.Deferred()
     reactor.callLater(0, d.callback, 0)
@@ -362,6 +395,31 @@ class DescriptorTestCase(unittest.TestCase):
         d = obj.fn(1)
         self.failureResultOf(d, SynapseError)
 
+    def test_invalidate_cascade(self):
+        """Invalidations should cascade up through cache contexts"""
+
+        class Cls:
+            @cached(cache_context=True)
+            async def func1(self, key, cache_context):
+                return await self.func2(key, on_invalidate=cache_context.invalidate)
+
+            @cached(cache_context=True)
+            async def func2(self, key, cache_context):
+                return self.func3(key, on_invalidate=cache_context.invalidate)
+
+            @lru_cache(cache_context=True)
+            def func3(self, key, cache_context):
+                self.invalidate = cache_context.invalidate
+                return 42
+
+        obj = Cls()
+
+        top_invalidate = mock.Mock()
+        r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate))
+        self.assertEqual(r, 42)
+        obj.invalidate()
+        top_invalidate.assert_called_once()
+
 
 class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     """More tests for @cached