summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py30
1 files changed, 24 insertions, 6 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 0033051849..88e56e3302 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -17,6 +17,7 @@ import logging
 from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.treecache import TreeCache
 
 from . import caches_by_name, DEBUG_CACHES, cache_counter
 
@@ -36,9 +37,12 @@ _CacheSentinel = object()
 
 class Cache(object):
 
-    def __init__(self, name, max_entries=1000, keylen=1, lru=True):
+    def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
         if lru:
-            self.cache = LruCache(max_size=max_entries)
+            cache_type = TreeCache if tree else dict
+            self.cache = LruCache(
+                max_size=max_entries, keylen=keylen, cache_type=cache_type
+            )
             self.max_entries = None
         else:
             self.cache = OrderedDict()
@@ -99,6 +103,15 @@ class Cache(object):
         self.sequence += 1
         self.cache.pop(key, None)
 
+    def invalidate_many(self, key):
+        self.check_thread()
+        if not isinstance(key, tuple):
+            raise TypeError(
+                "The cache key must be a tuple not %r" % (type(key),)
+            )
+        self.sequence += 1
+        self.cache.del_multi(key)
+
     def invalidate_all(self):
         self.check_thread()
         self.sequence += 1
@@ -122,7 +135,7 @@ class CacheDescriptor(object):
     which can be used to insert values into the cache specifically, without
     calling the calculation function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
+    def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
                  inlineCallbacks=False):
         self.orig = orig
 
@@ -134,6 +147,7 @@ class CacheDescriptor(object):
         self.max_entries = max_entries
         self.num_args = num_args
         self.lru = lru
+        self.tree = tree
 
         self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
 
@@ -149,6 +163,7 @@ class CacheDescriptor(object):
             max_entries=self.max_entries,
             keylen=self.num_args,
             lru=self.lru,
+            tree=self.tree,
         )
 
     def __get__(self, obj, objtype=None):
@@ -200,6 +215,7 @@ class CacheDescriptor(object):
 
         wrapped.invalidate = self.cache.invalidate
         wrapped.invalidate_all = self.cache.invalidate_all
+        wrapped.invalidate_many = self.cache.invalidate_many
         wrapped.prefill = self.cache.prefill
 
         obj.__dict__[self.orig.__name__] = wrapped
@@ -321,21 +337,23 @@ class CacheListDescriptor(object):
         return wrapped
 
 
-def cached(max_entries=1000, num_args=1, lru=True):
+def cached(max_entries=1000, num_args=1, lru=True, tree=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
-        lru=lru
+        lru=lru,
+        tree=tree,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
         lru=lru,
+        tree=tree,
         inlineCallbacks=True,
     )