diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index f31dfb22b7..8dba61d49f 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,8 +25,7 @@ from synapse.util.logcontext import (
from . import DEBUG_CACHES, register_cache
from twisted.internet import defer
-
-from collections import OrderedDict
+from collections import namedtuple
import os
import functools
@@ -54,16 +53,11 @@ class Cache(object):
"metrics",
)
- def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
- if lru:
- cache_type = TreeCache if tree else dict
- self.cache = LruCache(
- max_size=max_entries, keylen=keylen, cache_type=cache_type
- )
- self.max_entries = None
- else:
- self.cache = OrderedDict()
- self.max_entries = max_entries
+ def __init__(self, name, max_entries=1000, keylen=1, tree=False):
+ cache_type = TreeCache if tree else dict
+ self.cache = LruCache(
+ max_size=max_entries, keylen=keylen, cache_type=cache_type
+ )
self.name = name
self.keylen = keylen
@@ -81,8 +75,8 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
- def get(self, key, default=_CacheSentinel):
- val = self.cache.get(key, _CacheSentinel)
+ def get(self, key, default=_CacheSentinel, callback=None):
+ val = self.cache.get(key, _CacheSentinel, callback=callback)
if val is not _CacheSentinel:
self.metrics.inc_hits()
return val
@@ -94,19 +88,15 @@ class Cache(object):
else:
return default
- def update(self, sequence, key, value):
+ def update(self, sequence, key, value, callback=None):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
- self.prefill(key, value)
-
- def prefill(self, key, value):
- if self.max_entries is not None:
- while len(self.cache) >= self.max_entries:
- self.cache.popitem(last=False)
+ self.prefill(key, value, callback=callback)
- self.cache[key] = value
+ def prefill(self, key, value, callback=None):
+ self.cache.set(key, value, callback=callback)
def invalidate(self, key):
self.check_thread()
@@ -151,9 +141,21 @@ class CacheDescriptor(object):
The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without
calling the calculation function.
+
+ Cached functions can be "chained" (i.e. a cached function can call other cached
+ functions and get appropriately invalidated when they called caches are
+ invalidated) by adding a special "cache_context" argument to the function
+ and passing that as a kwarg to all caches called. For example::
+
+ @cachedInlineCallbacks(cache_context=True)
+ def foo(self, key, cache_context):
+ r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
+ r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
+ defer.returnValue(r1 + r2)
+
"""
- def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
- inlineCallbacks=False):
+ def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
+ inlineCallbacks=False, cache_context=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig
@@ -165,15 +167,33 @@ class CacheDescriptor(object):
self.max_entries = max_entries
self.num_args = num_args
- self.lru = lru
self.tree = tree
- self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
+ all_args = inspect.getargspec(orig)
+ self.arg_names = all_args.args[1:num_args + 1]
+
+ if "cache_context" in all_args.args:
+ if not cache_context:
+ raise ValueError(
+ "Cannot have a 'cache_context' arg without setting"
+ " cache_context=True"
+ )
+ try:
+ self.arg_names.remove("cache_context")
+ except ValueError:
+ pass
+ elif cache_context:
+ raise ValueError(
+ "Cannot have cache_context=True without having an arg"
+ " named `cache_context`"
+ )
+
+ self.add_cache_context = cache_context
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
- " (@cached cannot key off of *args or **kwars)"
+ " (@cached cannot key off of *args or **kwargs)"
% (orig.__name__,)
)
@@ -182,16 +202,29 @@ class CacheDescriptor(object):
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
- lru=self.lru,
tree=self.tree,
)
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
+ # If we're passed a cache_context then we'll want to call its invalidate()
+ # whenever we are invalidated
+ invalidate_callback = kwargs.pop("on_invalidate", None)
+
+ # Add temp cache_context so inspect.getcallargs doesn't explode
+ if self.add_cache_context:
+ kwargs["cache_context"] = None
+
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+
+ # Add our own `cache_context` to argument list if the wrapped function
+ # has asked for one
+ if self.add_cache_context:
+ kwargs["cache_context"] = _CacheContext(cache, cache_key)
+
try:
- cached_result_d = cache.get(cache_key)
+ cached_result_d = cache.get(cache_key, callback=invalidate_callback)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@@ -228,7 +261,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
- cache.update(sequence, cache_key, ret)
+ cache.update(sequence, cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe())
@@ -297,6 +330,10 @@ class CacheListDescriptor(object):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
+ # If we're passed a cache_context then we'll want to call its invalidate()
+ # whenever we are invalidated
+ invalidate_callback = kwargs.pop("on_invalidate", None)
+
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
@@ -311,7 +348,7 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg
try:
- res = cache.get(tuple(key))
+ res = cache.get(tuple(key), callback=invalidate_callback)
if not res.has_succeeded():
res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
@@ -345,7 +382,10 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
- cache.update(sequence, tuple(key), observer)
+ cache.update(
+ sequence, tuple(key), observer,
+ callback=invalidate_callback
+ )
def invalidate(f, key):
cache.invalidate(key)
@@ -376,24 +416,29 @@ class CacheListDescriptor(object):
return wrapped
-def cached(max_entries=1000, num_args=1, lru=True, tree=False):
+class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
+ def invalidate(self):
+ self.cache.invalidate(self.key)
+
+
+def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
- lru=lru,
tree=tree,
+ cache_context=cache_context,
)
-def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
- lru=lru,
tree=tree,
inlineCallbacks=True,
+ cache_context=cache_context,
)
|