summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/roommember.py3
-rw-r--r--synapse/storage/state.py2
-rw-r--r--synapse/util/caches/descriptors.py25
-rw-r--r--synapse/util/caches/lrucache.py32
-rw-r--r--tests/util/test_lrucache.py25
5 files changed, 66 insertions, 21 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 5d18037c7c..e63aab6ccf 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -390,7 +390,8 @@ class RoomMemberStore(SQLBaseStore):
             room_id, state_group, state_ids,
         )
 
-    @cachedInlineCallbacks(num_args=2, cache_context=True)
+    @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
+                           max_entries=2000)
     def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
                                        cache_context, event=None):
         # We don't use `state_group`, it's there so that we can cache based
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 7f466c40ac..c480743f89 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -284,7 +284,7 @@ class StateStore(SQLBaseStore):
             return [r[0] for r in results]
         return self.runInteraction("get_current_state_for_key", f)
 
-    @cached(num_args=2, max_entries=1000)
+    @cached(num_args=2, max_entries=1000, iterable=True)
     def _get_state_group_from_group(self, group, types):
         raise NotImplementedError()
 
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 8dba61d49f..d082c26b1f 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -42,6 +42,13 @@ _CacheSentinel = object()
 CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
 
 
+def deferred_size(deferred):
+    if deferred.called:
+        return len(deferred.result)
+    else:
+        return 1
+
+
 class Cache(object):
     __slots__ = (
         "cache",
@@ -53,10 +60,11 @@ class Cache(object):
         "metrics",
     )
 
-    def __init__(self, name, max_entries=1000, keylen=1, tree=False):
+    def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
         cache_type = TreeCache if tree else dict
         self.cache = LruCache(
-            max_size=max_entries, keylen=keylen, cache_type=cache_type
+            max_size=max_entries, keylen=keylen, cache_type=cache_type,
+            size_callback=deferred_size if iterable else None,
         )
 
         self.name = name
@@ -155,7 +163,7 @@ class CacheDescriptor(object):
 
     """
     def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
-                 inlineCallbacks=False, cache_context=False):
+                 inlineCallbacks=False, cache_context=False, iterable=False):
         max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
         self.orig = orig
@@ -169,6 +177,8 @@ class CacheDescriptor(object):
         self.num_args = num_args
         self.tree = tree
 
+        self.iterable = iterable
+
         all_args = inspect.getargspec(orig)
         self.arg_names = all_args.args[1:num_args + 1]
 
@@ -203,6 +213,7 @@ class CacheDescriptor(object):
             max_entries=self.max_entries,
             keylen=self.num_args,
             tree=self.tree,
+            iterable=self.iterable,
         )
 
         @functools.wraps(self.orig)
@@ -421,17 +432,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
+def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+           iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
         tree=tree,
         cache_context=cache_context,
+        iterable=iterable,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
+                          iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -439,6 +453,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
         tree=tree,
         inlineCallbacks=True,
         cache_context=cache_context,
+        iterable=iterable,
     )
 
 
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 9c4c679175..00ddf38290 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,7 +49,7 @@ class LruCache(object):
     Can also set callbacks on objects when getting/setting which are fired
     when that key gets invalidated/evicted.
     """
-    def __init__(self, max_size, keylen=1, cache_type=dict):
+    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
         cache = cache_type()
         self.cache = cache  # Used for introspection.
         list_root = _Node(None, None, None, None)
@@ -58,6 +58,18 @@ class LruCache(object):
 
         lock = threading.Lock()
 
+        def cache_len():
+            if size_callback is not None:
+                return sum(size_callback(node.value) for node in cache.itervalues())
+            else:
+                return len(cache)
+
+        def evict():
+            while cache_len() > max_size:
+                todelete = list_root.prev_node
+                delete_node(todelete)
+                cache.pop(todelete.key, None)
+
         def synchronized(f):
             @wraps(f)
             def inner(*args, **kwargs):
@@ -127,22 +139,18 @@ class LruCache(object):
                 else:
                     callbacks = set()
                 add_node(key, value, callbacks)
-                if len(cache) > max_size:
-                    todelete = list_root.prev_node
-                    delete_node(todelete)
-                    cache.pop(todelete.key, None)
+
+            evict()
 
         @synchronized
         def cache_set_default(key, value):
             node = cache.get(key, None)
             if node is not None:
+                evict()  # As the new node may be bigger than the old node.
                 return node.value
             else:
                 add_node(key, value)
-                if len(cache) > max_size:
-                    todelete = list_root.prev_node
-                    delete_node(todelete)
-                    cache.pop(todelete.key, None)
+                evict()
                 return value
 
         @synchronized
@@ -176,10 +184,6 @@ class LruCache(object):
             cache.clear()
 
         @synchronized
-        def cache_len():
-            return len(cache)
-
-        @synchronized
         def cache_contains(key):
             return key in cache
 
@@ -190,7 +194,7 @@ class LruCache(object):
         self.pop = cache_pop
         if cache_type is TreeCache:
             self.del_multi = cache_del_multi
-        self.len = cache_len
+        self.len = synchronized(cache_len)
         self.contains = cache_contains
         self.clear = cache_clear
 
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 1eba5b535e..d888a64d0a 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -232,3 +232,28 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
         self.assertEquals(m1.call_count, 1)
         self.assertEquals(m2.call_count, 0)
         self.assertEquals(m3.call_count, 1)
+
+
+class LruCacheSizedTestCase(unittest.TestCase):
+
+    def test_evict(self):
+        cache = LruCache(5, size_callback=len)
+        cache["key1"] = [0]
+        cache["key2"] = [1, 2]
+        cache["key3"] = [3]
+        cache["key4"] = [4]
+
+        self.assertEquals(cache["key1"], [0])
+        self.assertEquals(cache["key2"], [1, 2])
+        self.assertEquals(cache["key3"], [3])
+        self.assertEquals(cache["key4"], [4])
+        self.assertEquals(len(cache), 5)
+
+        cache["key5"] = [5, 6]
+
+        self.assertEquals(len(cache), 4)
+        self.assertEquals(cache.get("key1"), None)
+        self.assertEquals(cache.get("key2"), None)
+        self.assertEquals(cache["key3"], [3])
+        self.assertEquals(cache["key4"], [4])
+        self.assertEquals(cache["key5"], [5, 6])