summary refs log tree commit diff
path: root/synapse/util/caches/lrucache.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/lrucache.py')
-rw-r--r--synapse/util/caches/lrucache.py32
1 files changed, 18 insertions, 14 deletions
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 9c4c679175..00ddf38290 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,7 +49,7 @@ class LruCache(object):
     Can also set callbacks on objects when getting/setting which are fired
     when that key gets invalidated/evicted.
     """
-    def __init__(self, max_size, keylen=1, cache_type=dict):
+    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
         cache = cache_type()
         self.cache = cache  # Used for introspection.
         list_root = _Node(None, None, None, None)
@@ -58,6 +58,18 @@ class LruCache(object):
 
         lock = threading.Lock()
 
+        def cache_len():
+            if size_callback is not None:
+                return sum(size_callback(node.value) for node in cache.itervalues())
+            else:
+                return len(cache)
+
+        def evict():
+            while cache_len() > max_size:
+                todelete = list_root.prev_node
+                delete_node(todelete)
+                cache.pop(todelete.key, None)
+
         def synchronized(f):
             @wraps(f)
             def inner(*args, **kwargs):
@@ -127,22 +139,18 @@ class LruCache(object):
                 else:
                     callbacks = set()
                 add_node(key, value, callbacks)
-                if len(cache) > max_size:
-                    todelete = list_root.prev_node
-                    delete_node(todelete)
-                    cache.pop(todelete.key, None)
+
+            evict()
 
         @synchronized
         def cache_set_default(key, value):
             node = cache.get(key, None)
             if node is not None:
+                evict()  # As the new node may be bigger than the old node.
                 return node.value
             else:
                 add_node(key, value)
-                if len(cache) > max_size:
-                    todelete = list_root.prev_node
-                    delete_node(todelete)
-                    cache.pop(todelete.key, None)
+                evict()
                 return value
 
         @synchronized
@@ -176,10 +184,6 @@ class LruCache(object):
             cache.clear()
 
         @synchronized
-        def cache_len():
-            return len(cache)
-
-        @synchronized
         def cache_contains(key):
             return key in cache
 
@@ -190,7 +194,7 @@ class LruCache(object):
         self.pop = cache_pop
         if cache_type is TreeCache:
             self.del_multi = cache_del_multi
-        self.len = cache_len
+        self.len = synchronized(cache_len)
         self.contains = cache_contains
         self.clear = cache_clear