summary refs log tree commit diff
path: root/synapse/util/caches
diff options
context:
space:
mode:
authorDavid Baker <dbkr@users.noreply.github.com>2016-01-22 14:47:48 +0000
committerDavid Baker <dbkr@users.noreply.github.com>2016-01-22 14:47:48 +0000
commit7a3fe48ba48d99f02b4bfd556199571f95fc8e1c (patch)
tree740e160954c47d5362c22661bb141c31cdfbf118 /synapse/util/caches
parentMerge pull request #517 from matrix-org/erikj/push_only_room (diff)
parentDon't add the member functiopn if we're not using treecache (diff)
downloadsynapse-7a3fe48ba48d99f02b4bfd556199571f95fc8e1c.tar.xz
Merge pull request #519 from matrix-org/dbkr/treecache
Make LRU caching tree-based so subtrees of the cache can be invalidated cheaply.
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--synapse/util/caches/descriptors.py30
-rw-r--r--synapse/util/caches/lrucache.py53
-rw-r--r--synapse/util/caches/treecache.py60
3 files changed, 128 insertions, 15 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 0033051849..88e56e3302 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -17,6 +17,7 @@ import logging
 from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.treecache import TreeCache
 
 from . import caches_by_name, DEBUG_CACHES, cache_counter
 
@@ -36,9 +37,12 @@ _CacheSentinel = object()
 
 class Cache(object):
 
-    def __init__(self, name, max_entries=1000, keylen=1, lru=True):
+    def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
         if lru:
-            self.cache = LruCache(max_size=max_entries)
+            cache_type = TreeCache if tree else dict
+            self.cache = LruCache(
+                max_size=max_entries, keylen=keylen, cache_type=cache_type
+            )
             self.max_entries = None
         else:
             self.cache = OrderedDict()
@@ -99,6 +103,15 @@ class Cache(object):
         self.sequence += 1
         self.cache.pop(key, None)
 
+    def invalidate_many(self, key):
+        self.check_thread()
+        if not isinstance(key, tuple):
+            raise TypeError(
+                "The cache key must be a tuple not %r" % (type(key),)
+            )
+        self.sequence += 1
+        self.cache.del_multi(key)
+
     def invalidate_all(self):
         self.check_thread()
         self.sequence += 1
@@ -122,7 +135,7 @@ class CacheDescriptor(object):
     which can be used to insert values into the cache specifically, without
     calling the calculation function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
+    def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
                  inlineCallbacks=False):
         self.orig = orig
 
@@ -134,6 +147,7 @@ class CacheDescriptor(object):
         self.max_entries = max_entries
         self.num_args = num_args
         self.lru = lru
+        self.tree = tree
 
         self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
 
@@ -149,6 +163,7 @@ class CacheDescriptor(object):
             max_entries=self.max_entries,
             keylen=self.num_args,
             lru=self.lru,
+            tree=self.tree,
         )
 
     def __get__(self, obj, objtype=None):
@@ -200,6 +215,7 @@ class CacheDescriptor(object):
 
         wrapped.invalidate = self.cache.invalidate
         wrapped.invalidate_all = self.cache.invalidate_all
+        wrapped.invalidate_many = self.cache.invalidate_many
         wrapped.prefill = self.cache.prefill
 
         obj.__dict__[self.orig.__name__] = wrapped
@@ -321,21 +337,23 @@ class CacheListDescriptor(object):
         return wrapped
 
 
-def cached(max_entries=1000, num_args=1, lru=True):
+def cached(max_entries=1000, num_args=1, lru=True, tree=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
-        lru=lru
+        lru=lru,
+        tree=tree,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
         lru=lru,
+        tree=tree,
         inlineCallbacks=True,
     )
 
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 0122b0bb3f..e6a66dc041 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -17,11 +17,27 @@
 from functools import wraps
 import threading
 
+from synapse.util.caches.treecache import TreeCache
+
+
+def enumerate_leaves(node, depth):
+    if depth == 0:
+        yield node
+    else:
+        for n in node.values():
+            for m in enumerate_leaves(n, depth - 1):
+                yield m
+
 
 class LruCache(object):
-    """Least-recently-used cache."""
-    def __init__(self, max_size):
-        cache = {}
+    """
+    Least-recently-used cache.
+    Supports del_multi only if cache_type=TreeCache
+    If cache_type=TreeCache, all keys must be tuples.
+    """
+    def __init__(self, max_size, keylen=1, cache_type=dict):
+        cache = cache_type()
+        self.size = 0
         list_root = []
         list_root[:] = [list_root, list_root, None, None]
 
@@ -44,6 +60,7 @@ class LruCache(object):
             prev_node[NEXT] = node
             next_node[PREV] = node
             cache[key] = node
+            self.size += 1
 
         def move_node_to_front(node):
             prev_node = node[PREV]
@@ -62,7 +79,7 @@ class LruCache(object):
             next_node = node[NEXT]
             prev_node[NEXT] = next_node
             next_node[PREV] = prev_node
-            cache.pop(node[KEY], None)
+            self.size -= 1
 
         @synchronized
         def cache_get(key, default=None):
@@ -81,8 +98,10 @@ class LruCache(object):
                 node[VALUE] = value
             else:
                 add_node(key, value)
-                if len(cache) > max_size:
-                    delete_node(list_root[PREV])
+                if self.size > max_size:
+                    todelete = list_root[PREV]
+                    delete_node(todelete)
+                    cache.pop(todelete[KEY], None)
 
         @synchronized
         def cache_set_default(key, value):
@@ -91,8 +110,10 @@ class LruCache(object):
                 return node[VALUE]
             else:
                 add_node(key, value)
-                if len(cache) > max_size:
-                    delete_node(list_root[PREV])
+                if self.size > max_size:
+                    todelete = list_root[PREV]
+                    delete_node(todelete)
+                    cache.pop(todelete[KEY], None)
                 return value
 
         @synchronized
@@ -100,11 +121,23 @@ class LruCache(object):
             node = cache.get(key, None)
             if node:
                 delete_node(node)
+                cache.pop(node[KEY], None)
                 return node[VALUE]
             else:
                 return default
 
         @synchronized
+        def cache_del_multi(key):
+            """
+            This will only work if constructed with cache_type=TreeCache
+            """
+            popped = cache.pop(key)
+            if popped is None:
+                return
+            for leaf in enumerate_leaves(popped, keylen - len(key)):
+                delete_node(leaf)
+
+        @synchronized
         def cache_clear():
             list_root[NEXT] = list_root
             list_root[PREV] = list_root
@@ -112,7 +145,7 @@ class LruCache(object):
 
         @synchronized
         def cache_len():
-            return len(cache)
+            return self.size
 
         @synchronized
         def cache_contains(key):
@@ -123,6 +156,8 @@ class LruCache(object):
         self.set = cache_set
         self.setdefault = cache_set_default
         self.pop = cache_pop
+        if cache_type is TreeCache:
+            self.del_multi = cache_del_multi
         self.len = cache_len
         self.contains = cache_contains
         self.clear = cache_clear
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
new file mode 100644
index 0000000000..3b58860910
--- /dev/null
+++ b/synapse/util/caches/treecache.py
@@ -0,0 +1,60 @@
+SENTINEL = object()
+
+
+class TreeCache(object):
+    """
+    Tree-based backing store for LruCache. Allows subtrees of data to be deleted
+    efficiently.
+    Keys must be tuples.
+    """
+    def __init__(self):
+        self.root = {}
+
+    def __setitem__(self, key, value):
+        return self.set(key, value)
+
+    def __contains__(self, key):
+        return self.get(key, SENTINEL) is not SENTINEL
+
+    def set(self, key, value):
+        node = self.root
+        for k in key[:-1]:
+            node = node.setdefault(k, {})
+        node[key[-1]] = value
+
+    def get(self, key, default=None):
+        node = self.root
+        for k in key[:-1]:
+            node = node.get(k, None)
+            if node is None:
+                return default
+        return node.get(key[-1], default)
+
+    def clear(self):
+        self.root = {}
+
+    def pop(self, key, default=None):
+        nodes = []
+
+        node = self.root
+        for k in key[:-1]:
+            node = node.get(k, None)
+            nodes.append(node)  # don't add the root node
+            if node is None:
+                return default
+        popped = node.pop(key[-1], SENTINEL)
+        if popped is SENTINEL:
+            return default
+
+        node_and_keys = zip(nodes, key)
+        node_and_keys.reverse()
+        node_and_keys.append((self.root, None))
+
+        for i in range(len(node_and_keys) - 1):
+            n, k = node_and_keys[i]
+
+            if n:
+                break
+            node_and_keys[i+1][0].pop(k)
+
+        return popped