diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index d9b5478b53..82a72dc34f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,8 +15,8 @@
# limitations under the License.
import logging
-from collections import namedtuple
+import attr
from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership, RelationTypes
@@ -26,7 +26,8 @@ from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import lru_cache
+from synapse.util.caches.lrucache import LruCache
from .push_rule_evaluator import PushRuleEvaluatorForEvent
@@ -120,7 +121,7 @@ class BulkPushRuleEvaluator:
dict of user_id -> push_rules
"""
room_id = event.room_id
- rules_for_room = await self._get_rules_for_room(room_id)
+ rules_for_room = self._get_rules_for_room(room_id)
rules_by_user = await rules_for_room.get_rules(event, context)
@@ -138,7 +139,7 @@ class BulkPushRuleEvaluator:
return rules_by_user
- @cached()
+ @lru_cache()
def _get_rules_for_room(self, room_id):
"""Get the current RulesForRoom object for the given room id
@@ -275,12 +276,14 @@ class RulesForRoom:
the entire cache for the room.
"""
- def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
+ def __init__(
+ self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
+ ):
"""
Args:
hs (HomeServer)
room_id (str)
- rules_for_room_cache(Cache): The cache object that caches these
+ rules_for_room_cache: The cache object that caches these
RoomsForUser objects.
room_push_rule_cache_metrics (CacheMetric)
"""
@@ -489,13 +492,21 @@ class RulesForRoom:
self.state_group = state_group
-class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
- # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
- # which namedtuple does for us (i.e. two _CacheContext are the same if
- # their caches and keys match). This is important in particular to
- # dedupe when we add callbacks to lru cache nodes, otherwise the number
- # of callbacks would grow.
+@attr.attrs(slots=True, frozen=True)
+class _Invalidation:
+ # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
+    # which means that it is stored on the bulk_get_push_rules cache entry. In order
+    # to ensure that we don't accumulate lots of redundant callbacks on the cache entry,
+ # we need to ensure that two _Invalidation objects are "equal" if they refer to the
+ # same `cache` and `room_id`.
+ #
+ # attrs provides suitable __hash__ and __eq__ methods, provided we remember to
+ # set `frozen=True`.
+
+ cache = attr.ib(type=LruCache)
+ room_id = attr.ib(type=str)
+
def __call__(self):
- rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
+ rules = self.cache.get(self.room_id, None, update_metrics=False)
if rules:
rules.invalidate_all()
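
The comment in `_Invalidation` above is the point of the namedtuple-to-attrs
conversion: with `frozen=True`, attrs generates value-based `__eq__` and
`__hash__`, so two callbacks for the same cache and room id compare equal and
dedupe when attached to a cache entry. A minimal standalone sketch of that
behaviour (`DummyCache` is a hypothetical stand-in for `LruCache`, for
illustration only):

    import attr

    @attr.attrs(slots=True, frozen=True)
    class _Invalidation:
        cache = attr.ib()            # an LruCache in the real code
        room_id = attr.ib(type=str)

    class DummyCache:
        """Hypothetical stand-in for LruCache, for illustration only."""

    cache = DummyCache()
    a = _Invalidation(cache, "!room:example.org")
    b = _Invalidation(cache, "!room:example.org")

    # Equal by value, so duplicates collapse when stored in a set-like
    # callback collection:
    assert a == b and hash(a) == hash(b)
    assert len({a, b}) == 1
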
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5d7fffee66..a924140cdf 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,10 +13,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import enum
import functools
import inspect
import logging
-from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ Iterable,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+ cast,
+)
from weakref import WeakValueDictionary
from twisted.internet import defer
@@ -24,6 +37,7 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
from synapse.util.caches.deferred_cache import DeferredCache
+from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -48,7 +62,7 @@ class _CachedFunction(Generic[F]):
class _CacheDescriptorBase:
- def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
+ def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
self.orig = orig
arg_spec = inspect.getfullargspec(orig)
@@ -97,8 +111,107 @@ class _CacheDescriptorBase:
self.add_cache_context = cache_context
+ self.cache_key_builder = get_cache_key_builder(
+ self.arg_names, self.arg_defaults
+ )
+
+
+class _LruCachedFunction(Generic[F]):
+ cache = None # type: LruCache[CacheKey, Any]
+ __call__ = None # type: F
+
+
+def lru_cache(
+ max_entries: int = 1000, cache_context: bool = False,
+) -> Callable[[F], _LruCachedFunction[F]]:
+ """A method decorator that applies a memoizing cache around the function.
+
+ This is more-or-less a drop-in equivalent to functools.lru_cache, although note
+ that the signature is slightly different.
+
+ The main differences with functools.lru_cache are:
+ (a) the size of the cache can be controlled via the cache_factor mechanism
+ (b) the wrapped function can request a "cache_context" which provides a
+ callback mechanism to indicate that the result is no longer valid
+ (c) prometheus metrics are exposed automatically.
+
+ The function should take zero or more arguments, which are used as the key for the
+ cache. Single-argument functions use that argument as the cache key; otherwise the
+ arguments are built into a tuple.
+
+ Cached functions can be "chained" (i.e. a cached function can call other cached
+    functions and get appropriately invalidated when the caches they call are
+ invalidated) by adding a special "cache_context" argument to the function
+ and passing that as a kwarg to all caches called. For example:
+
+ @lru_cache(cache_context=True)
+ def foo(self, key, cache_context):
+ r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
+ r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
+ return r1 + r2
+
+ The wrapped function also has a 'cache' property which offers direct access to the
+ underlying LruCache.
+ """
+
+ def func(orig: F) -> _LruCachedFunction[F]:
+ desc = LruCacheDescriptor(
+ orig, max_entries=max_entries, cache_context=cache_context,
+ )
+ return cast(_LruCachedFunction[F], desc)
+
+ return func
+
+
+class LruCacheDescriptor(_CacheDescriptorBase):
+ """Helper for @lru_cache"""
+
+ class _Sentinel(enum.Enum):
+ sentinel = object()
+
+ def __init__(
+ self, orig, max_entries: int = 1000, cache_context: bool = False,
+ ):
+ super().__init__(orig, num_args=None, cache_context=cache_context)
+ self.max_entries = max_entries
+
+ def __get__(self, obj, owner):
+ cache = LruCache(
+ cache_name=self.orig.__name__, max_size=self.max_entries,
+ ) # type: LruCache[CacheKey, Any]
+
+ get_cache_key = self.cache_key_builder
+ sentinel = LruCacheDescriptor._Sentinel.sentinel
+
+ @functools.wraps(self.orig)
+ def _wrapped(*args, **kwargs):
+ invalidate_callback = kwargs.pop("on_invalidate", None)
+ callbacks = (invalidate_callback,) if invalidate_callback else ()
+
+ cache_key = get_cache_key(args, kwargs)
+
+            ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
+            if ret is not sentinel:
+ return ret
+
+ # Add our own `cache_context` to argument list if the wrapped function
+ # has asked for one
+ if self.add_cache_context:
+ kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
+
+ ret2 = self.orig(obj, *args, **kwargs)
+ cache.set(cache_key, ret2, callbacks=callbacks)
+
+ return ret2
+
+ wrapped = cast(_CachedFunction, _wrapped)
+ wrapped.cache = cache
+ obj.__dict__[self.orig.__name__] = wrapped
+
+ return wrapped
+
+
-class CacheDescriptor(_CacheDescriptorBase):
+class DeferredCacheDescriptor(_CacheDescriptorBase):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
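
For reference, a minimal usage sketch of the new `@lru_cache` decorator
introduced in the hunk above (the class and method are invented for
illustration, and this assumes a synapse checkout). Note that, unlike
`@cached`, the call is synchronous, which is why the `await` was dropped at
the `_get_rules_for_room` call site in the first file; the per-instance
wrapper also exposes the underlying `LruCache` via its `cache` attribute:

    from synapse.util.caches.descriptors import lru_cache

    class RoomRuleLookup:
        @lru_cache(max_entries=100)
        def get_rules(self, room_id):
            # Stand-in for an expensive computation.
            return ["rule-for-" + room_id]

    lookup = RoomRuleLookup()
    r1 = lookup.get_rules("!a:example.org")  # miss: runs the function
    r2 = lookup.get_rules("!a:example.org")  # hit: served from the LruCache
    assert r1 is r2

    # Single-argument functions use the bare argument as the cache key, so
    # the entry can be dropped with LruCache.pop:
    lookup.get_rules.cache.pop("!a:example.org")
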
@@ -141,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase):
cache_context=False,
iterable=False,
):
-
super().__init__(orig, num_args=num_args, cache_context=cache_context)
self.max_entries = max_entries
@@ -157,41 +269,7 @@ class CacheDescriptor(_CacheDescriptorBase):
iterable=self.iterable,
) # type: DeferredCache[CacheKey, Any]
- def get_cache_key_gen(args, kwargs):
- """Given some args/kwargs return a generator that resolves into
- the cache_key.
-
- We loop through each arg name, looking up if its in the `kwargs`,
- otherwise using the next argument in `args`. If there are no more
- args then we try looking the arg name up in the defaults
- """
- pos = 0
- for nm in self.arg_names:
- if nm in kwargs:
- yield kwargs[nm]
- elif pos < len(args):
- yield args[pos]
- pos += 1
- else:
- yield self.arg_defaults[nm]
-
- # By default our cache key is a tuple, but if there is only one item
- # then don't bother wrapping in a tuple. This is to save memory.
- if self.num_args == 1:
- nm = self.arg_names[0]
-
- def get_cache_key(args, kwargs):
- if nm in kwargs:
- return kwargs[nm]
- elif len(args):
- return args[0]
- else:
- return self.arg_defaults[nm]
-
- else:
-
- def get_cache_key(args, kwargs):
- return tuple(get_cache_key_gen(args, kwargs))
+ get_cache_key = self.cache_key_builder
@functools.wraps(self.orig)
def _wrapped(*args, **kwargs):
@@ -223,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase):
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
else:
wrapped.invalidate = cache.invalidate
- wrapped.invalidate_all = cache.invalidate_all
wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = cache.prefill
@@ -236,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return wrapped
-class CacheListDescriptor(_CacheDescriptorBase):
+class DeferredCacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes
@@ -382,11 +459,13 @@ class _CacheContext:
on a lower level.
"""
+ Cache = Union[DeferredCache, LruCache]
+
_cache_context_objects = (
WeakValueDictionary()
- ) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
+ ) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
- def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None
+ def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
self._cache = cache
self._cache_key = cache_key
@@ -396,8 +475,8 @@ class _CacheContext:
@classmethod
def get_instance(
- cls, cache, cache_key
- ): # type: (DeferredCache, CacheKey) -> _CacheContext
+ cls, cache: "_CacheContext.Cache", cache_key: CacheKey
+ ) -> "_CacheContext":
"""Returns an instance constructed with the given arguments.
A new instance is only created if none already exists.
@@ -418,7 +497,7 @@ def cached(
cache_context: bool = False,
iterable: bool = False,
) -> Callable[[F], _CachedFunction[F]]:
- func = lambda orig: CacheDescriptor(
+ func = lambda orig: DeferredCacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
@@ -460,7 +539,7 @@ def cachedList(
def batch_do_something(self, first_arg, second_args):
...
"""
- func = lambda orig: CacheListDescriptor(
+ func = lambda orig: DeferredCacheListDescriptor(
orig,
cached_method_name=cached_method_name,
list_name=list_name,
@@ -468,3 +547,65 @@ def cachedList(
)
return cast(Callable[[F], _CachedFunction[F]], func)
+
+
+def get_cache_key_builder(
+ param_names: Sequence[str], param_defaults: Mapping[str, Any]
+) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
+ """Construct a function which will build cache keys suitable for a cached function
+
+ Args:
+ param_names: list of formal parameter names for the cached function
+ param_defaults: a mapping from parameter name to default value for that param
+
+ Returns:
+ A function which will take an (args, kwargs) pair and return a cache key
+ """
+
+ # By default our cache key is a tuple, but if there is only one item
+ # then don't bother wrapping in a tuple. This is to save memory.
+
+ if len(param_names) == 1:
+ nm = param_names[0]
+
+ def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+ if nm in kwargs:
+ return kwargs[nm]
+ elif len(args):
+ return args[0]
+ else:
+ return param_defaults[nm]
+
+ else:
+
+ def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+ return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
+
+ return get_cache_key
+
+
+def _get_cache_key_gen(
+ param_names: Iterable[str],
+ param_defaults: Mapping[str, Any],
+ args: Sequence[Any],
+ kwargs: Mapping[str, Any],
+) -> Iterable[Any]:
+ """Given some args/kwargs return a generator that resolves into
+ the cache_key.
+
+ This is essentially the same operation as `inspect.getcallargs`, but optimised so
+ that we don't need to inspect the target function for each call.
+ """
+
+    # We loop through each arg name, looking it up in `kwargs` if present,
+ # otherwise using the next argument in `args`. If there are no more
+ # args then we try looking the arg name up in the defaults.
+ pos = 0
+ for nm in param_names:
+ if nm in kwargs:
+ yield kwargs[nm]
+ elif pos < len(args):
+ yield args[pos]
+ pos += 1
+ else:
+ yield param_defaults[nm]
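
To make the key-building rules concrete, a short sketch of
`get_cache_key_builder` in action (parameter names invented for
illustration): a single-parameter function keys on the bare argument to save
the tuple allocation, while multi-parameter functions key on a tuple, with
kwargs and defaults slotted into their declared positions:

    from synapse.util.caches.descriptors import get_cache_key_builder

    # One parameter: the key is the bare argument, not a 1-tuple.
    one = get_cache_key_builder(["room_id"], {})
    assert one(("!a:example.org",), {}) == "!a:example.org"

    # Several parameters: positional args, kwargs and defaults are all
    # resolved into a single tuple, in declaration order.
    many = get_cache_key_builder(["room_id", "user_id", "limit"], {"limit": 10})
    key = many(("!a:example.org",), {"user_id": "@bob:example.org"})
    assert key == ("!a:example.org", "@bob:example.org", 10)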