diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 0033051849..af7bf15500 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -38,7 +38,7 @@ class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=True):
if lru:
- self.cache = LruCache(max_size=max_entries)
+ self.cache = LruCache(max_size=max_entries, keylen=keylen)
self.max_entries = None
else:
self.cache = OrderedDict()
@@ -99,6 +99,15 @@ class Cache(object):
self.sequence += 1
self.cache.pop(key, None)
+ def invalidate_many(self, key):
+ self.check_thread()
+ if not isinstance(key, tuple):
+ raise TypeError(
+ "The cache key must be a tuple, not %r" % (type(key),)
+ )
+ self.sequence += 1
+ self.cache.del_multi(key)
+
def invalidate_all(self):
self.check_thread()
self.sequence += 1
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index f92d80542b..b7964467eb 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -32,7 +32,7 @@ class DictionaryCache(object):
"""
def __init__(self, name, max_entries=1000):
- self.cache = LruCache(max_size=max_entries)
+ self.cache = LruCache(max_size=max_entries, keylen=1)
self.name = name
self.sequence = 0
@@ -56,7 +56,7 @@ class DictionaryCache(object):
)
def get(self, key, dict_keys=None):
- entry = self.cache.get(key, self.sentinel)
+ entry = self.cache.get((key,), self.sentinel)
if entry is not self.sentinel:
cache_counter.inc_hits(self.name)
@@ -78,7 +78,7 @@ class DictionaryCache(object):
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
- self.cache.pop(key, None)
+ self.cache.pop((key,), None)
def invalidate_all(self):
self.check_thread()
@@ -96,8 +96,8 @@ class DictionaryCache(object):
self._update_or_insert(key, value)
def _update_or_insert(self, key, value):
- entry = self.cache.setdefault(key, DictionaryEntry(False, {}))
+ entry = self.cache.setdefault((key,), DictionaryEntry(False, {}))
entry.value.update(value)
def _insert(self, key, value):
- self.cache[key] = DictionaryEntry(True, value)
+ self.cache[(key,)] = DictionaryEntry(True, value)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 0122b0bb3f..0feceb298a 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -17,11 +17,23 @@
from functools import wraps
import threading
+from synapse.util.caches.treecache import TreeCache
+
+
+def enumerate_leaves(node, depth):
+ if depth == 0:
+ yield node
+ else:
+ for n in node.values():
+ for m in enumerate_leaves(n, depth - 1):
+ yield m
+
class LruCache(object):
"""Least-recently-used cache."""
- def __init__(self, max_size):
- cache = {}
+ def __init__(self, max_size, keylen):
+ cache = TreeCache()
+ self.size = 0
list_root = []
list_root[:] = [list_root, list_root, None, None]
@@ -44,6 +56,7 @@ class LruCache(object):
prev_node[NEXT] = node
next_node[PREV] = node
cache[key] = node
+ self.size += 1
def move_node_to_front(node):
prev_node = node[PREV]
@@ -62,7 +75,7 @@ class LruCache(object):
next_node = node[NEXT]
prev_node[NEXT] = next_node
next_node[PREV] = prev_node
- cache.pop(node[KEY], None)
+ self.size -= 1
@synchronized
def cache_get(key, default=None):
@@ -81,8 +94,10 @@ class LruCache(object):
node[VALUE] = value
else:
add_node(key, value)
- if len(cache) > max_size:
- delete_node(list_root[PREV])
+ if self.size > max_size:
+ todelete = list_root[PREV]
+ delete_node(todelete)
+ cache.pop(todelete[KEY], None)
@synchronized
def cache_set_default(key, value):
@@ -91,8 +106,10 @@ class LruCache(object):
return node[VALUE]
else:
add_node(key, value)
- if len(cache) > max_size:
- delete_node(list_root[PREV])
+ if self.size > max_size:
+ todelete = list_root[PREV]
+ delete_node(todelete)
+ cache.pop(todelete[KEY], None)
return value
@synchronized
@@ -100,11 +117,20 @@ class LruCache(object):
node = cache.get(key, None)
if node:
delete_node(node)
+ cache.pop(node[KEY], None)
return node[VALUE]
else:
return default
@synchronized
+ def cache_del_multi(key):
+ popped = cache.pop(key)
+ if popped is None:
+ return
+ for leaf in enumerate_leaves(popped, keylen - len(key)):
+ delete_node(leaf)
+
+ @synchronized
def cache_clear():
list_root[NEXT] = list_root
list_root[PREV] = list_root
@@ -112,7 +138,7 @@ class LruCache(object):
@synchronized
def cache_len():
- return len(cache)
+ return self.size
@synchronized
def cache_contains(key):
@@ -123,6 +149,7 @@ class LruCache(object):
self.set = cache_set
self.setdefault = cache_set_default
self.pop = cache_pop
+ self.del_multi = cache_del_multi
self.len = cache_len
self.contains = cache_contains
self.clear = cache_clear
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
new file mode 100644
index 0000000000..1e5f87e6ad
--- /dev/null
+++ b/synapse/util/caches/treecache.py
@@ -0,0 +1,52 @@
+SENTINEL = object()
+
+
+class TreeCache(object):
+ def __init__(self):
+ self.root = {}
+
+ def __setitem__(self, key, value):
+ return self.set(key, value)
+
+ def set(self, key, value):
+ node = self.root
+ for k in key[:-1]:
+ node = node.setdefault(k, {})
+ node[key[-1]] = value
+
+ def get(self, key, default=None):
+ node = self.root
+ for k in key[:-1]:
+ node = node.get(k, None)
+ if node is None:
+ return default
+ return node.get(key[-1], default)
+
+ def clear(self):
+ self.root = {}
+
+ def pop(self, key, default=None):
+ nodes = []
+
+ node = self.root
+ for k in key[:-1]:
+ node = node.get(k, None)
+ nodes.append(node) # don't add the root node
+ if node is None:
+ return default
+ popped = node.pop(key[-1], SENTINEL)
+ if popped is SENTINEL:
+ return default
+
+ node_and_keys = list(zip(nodes, key))
+ node_and_keys.reverse()
+ node_and_keys.append((self.root, None))
+
+ for i in range(len(node_and_keys) - 1):
+ n, k = node_and_keys[i]
+
+ if n:
+ break
+ node_and_keys[i+1][0].pop(k)
+
+ return popped
\ No newline at end of file
|