diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 8dba61d49f..d082c26b1f 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -42,6 +42,13 @@ _CacheSentinel = object()
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
+def deferred_size(deferred):
+ if deferred.called:
+ return len(deferred.result)
+ else:
+ return 1
+
+
class Cache(object):
__slots__ = (
"cache",
@@ -53,10 +60,11 @@ class Cache(object):
"metrics",
)
- def __init__(self, name, max_entries=1000, keylen=1, tree=False):
+ def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
cache_type = TreeCache if tree else dict
self.cache = LruCache(
- max_size=max_entries, keylen=keylen, cache_type=cache_type
+ max_size=max_entries, keylen=keylen, cache_type=cache_type,
+ size_callback=deferred_size if iterable else None,
)
self.name = name
@@ -155,7 +163,7 @@ class CacheDescriptor(object):
"""
def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
- inlineCallbacks=False, cache_context=False):
+ inlineCallbacks=False, cache_context=False, iterable=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig
@@ -169,6 +177,8 @@ class CacheDescriptor(object):
self.num_args = num_args
self.tree = tree
+ self.iterable = iterable
+
all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1]
@@ -203,6 +213,7 @@ class CacheDescriptor(object):
max_entries=self.max_entries,
keylen=self.num_args,
tree=self.tree,
+ iterable=self.iterable,
)
@functools.wraps(self.orig)
@@ -421,17 +432,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
self.cache.invalidate(self.key)
-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
+def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+ iterable=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
tree=tree,
cache_context=cache_context,
+ iterable=iterable,
)
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
+ iterable=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
@@ -439,6 +453,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
tree=tree,
inlineCallbacks=True,
cache_context=cache_context,
+ iterable=iterable,
)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 9c4c679175..00ddf38290 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,7 +49,7 @@ class LruCache(object):
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
- def __init__(self, max_size, keylen=1, cache_type=dict):
+ def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
cache = cache_type()
self.cache = cache # Used for introspection.
list_root = _Node(None, None, None, None)
@@ -58,6 +58,18 @@ class LruCache(object):
lock = threading.Lock()
+ def cache_len():
+ if size_callback is not None:
+ return sum(size_callback(node.value) for node in cache.itervalues())
+ else:
+ return len(cache)
+
+ def evict():
+ while cache_len() > max_size:
+ todelete = list_root.prev_node
+ delete_node(todelete)
+ cache.pop(todelete.key, None)
+
def synchronized(f):
@wraps(f)
def inner(*args, **kwargs):
@@ -127,22 +139,18 @@ class LruCache(object):
else:
callbacks = set()
add_node(key, value, callbacks)
- if len(cache) > max_size:
- todelete = list_root.prev_node
- delete_node(todelete)
- cache.pop(todelete.key, None)
+
+ evict()
@synchronized
def cache_set_default(key, value):
node = cache.get(key, None)
if node is not None:
+ evict() # As the new node may be bigger than the old node.
return node.value
else:
add_node(key, value)
- if len(cache) > max_size:
- todelete = list_root.prev_node
- delete_node(todelete)
- cache.pop(todelete.key, None)
+ evict()
return value
@synchronized
@@ -176,10 +184,6 @@ class LruCache(object):
cache.clear()
@synchronized
- def cache_len():
- return len(cache)
-
- @synchronized
def cache_contains(key):
return key in cache
@@ -190,7 +194,7 @@ class LruCache(object):
self.pop = cache_pop
if cache_type is TreeCache:
self.del_multi = cache_del_multi
- self.len = cache_len
+ self.len = synchronized(cache_len)
self.contains = cache_contains
self.clear = cache_clear
|