summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py25
1 files changed, 20 insertions, 5 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 8dba61d49f..d082c26b1f 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -42,6 +42,13 @@ _CacheSentinel = object()
 CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
 
 
+def deferred_size(deferred):
+    if deferred.called:
+        return len(deferred.result)
+    else:
+        return 1
+
+
 class Cache(object):
     __slots__ = (
         "cache",
@@ -53,10 +60,11 @@ class Cache(object):
         "metrics",
     )
 
-    def __init__(self, name, max_entries=1000, keylen=1, tree=False):
+    def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
         cache_type = TreeCache if tree else dict
         self.cache = LruCache(
-            max_size=max_entries, keylen=keylen, cache_type=cache_type
+            max_size=max_entries, keylen=keylen, cache_type=cache_type,
+            size_callback=deferred_size if iterable else None,
         )
 
         self.name = name
@@ -155,7 +163,7 @@ class CacheDescriptor(object):
 
     """
     def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
-                 inlineCallbacks=False, cache_context=False):
+                 inlineCallbacks=False, cache_context=False, iterable=False):
         max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
         self.orig = orig
@@ -169,6 +177,8 @@ class CacheDescriptor(object):
         self.num_args = num_args
         self.tree = tree
 
+        self.iterable = iterable
+
         all_args = inspect.getargspec(orig)
         self.arg_names = all_args.args[1:num_args + 1]
 
@@ -203,6 +213,7 @@ class CacheDescriptor(object):
             max_entries=self.max_entries,
             keylen=self.num_args,
             tree=self.tree,
+            iterable=self.iterable,
         )
 
         @functools.wraps(self.orig)
@@ -421,17 +432,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
+def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+           iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
         tree=tree,
         cache_context=cache_context,
+        iterable=iterable,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
+                          iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -439,6 +453,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
         tree=tree,
         inlineCallbacks=True,
         cache_context=cache_context,
+        iterable=iterable,
     )