diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index f7423f2fab..9c4c679175 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -29,19 +29,32 @@ def enumerate_leaves(node, depth):
yield m
+class _Node(object):
+ __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
+
+ def __init__(self, prev_node, next_node, key, value, callbacks=set()):
+ self.prev_node = prev_node
+ self.next_node = next_node
+ self.key = key
+ self.value = value
+ self.callbacks = callbacks
+
+
class LruCache(object):
"""
Least-recently-used cache.
Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples.
+
+ Can also set callbacks on objects when getting/setting which are fired
+ when that key gets invalidated/evicted.
"""
def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type()
self.cache = cache # Used for introspection.
- list_root = []
- list_root[:] = [list_root, list_root, None, None]
-
- PREV, NEXT, KEY, VALUE = 0, 1, 2, 3
+ list_root = _Node(None, None, None, None)
+ list_root.next_node = list_root
+ list_root.prev_node = list_root
lock = threading.Lock()
@@ -53,65 +66,83 @@ class LruCache(object):
return inner
- def add_node(key, value):
+ def add_node(key, value, callbacks=set()):
prev_node = list_root
- next_node = prev_node[NEXT]
- node = [prev_node, next_node, key, value]
- prev_node[NEXT] = node
- next_node[PREV] = node
+ next_node = prev_node.next_node
+ node = _Node(prev_node, next_node, key, value, callbacks)
+ prev_node.next_node = node
+ next_node.prev_node = node
cache[key] = node
def move_node_to_front(node):
- prev_node = node[PREV]
- next_node = node[NEXT]
- prev_node[NEXT] = next_node
- next_node[PREV] = prev_node
+ prev_node = node.prev_node
+ next_node = node.next_node
+ prev_node.next_node = next_node
+ next_node.prev_node = prev_node
prev_node = list_root
- next_node = prev_node[NEXT]
- node[PREV] = prev_node
- node[NEXT] = next_node
- prev_node[NEXT] = node
- next_node[PREV] = node
+ next_node = prev_node.next_node
+ node.prev_node = prev_node
+ node.next_node = next_node
+ prev_node.next_node = node
+ next_node.prev_node = node
def delete_node(node):
- prev_node = node[PREV]
- next_node = node[NEXT]
- prev_node[NEXT] = next_node
- next_node[PREV] = prev_node
+ prev_node = node.prev_node
+ next_node = node.next_node
+ prev_node.next_node = next_node
+ next_node.prev_node = prev_node
+
+ for cb in node.callbacks:
+ cb()
+ node.callbacks.clear()
@synchronized
- def cache_get(key, default=None):
+ def cache_get(key, default=None, callback=None):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
- return node[VALUE]
+ if callback:
+ node.callbacks.add(callback)
+ return node.value
else:
return default
@synchronized
- def cache_set(key, value):
+ def cache_set(key, value, callback=None):
node = cache.get(key, None)
if node is not None:
+ if value != node.value:
+ for cb in node.callbacks:
+ cb()
+ node.callbacks.clear()
+
+ if callback:
+ node.callbacks.add(callback)
+
move_node_to_front(node)
- node[VALUE] = value
+ node.value = value
else:
- add_node(key, value)
+ if callback:
+ callbacks = set([callback])
+ else:
+ callbacks = set()
+ add_node(key, value, callbacks)
if len(cache) > max_size:
- todelete = list_root[PREV]
+ todelete = list_root.prev_node
delete_node(todelete)
- cache.pop(todelete[KEY], None)
+ cache.pop(todelete.key, None)
@synchronized
def cache_set_default(key, value):
node = cache.get(key, None)
if node is not None:
- return node[VALUE]
+ return node.value
else:
add_node(key, value)
if len(cache) > max_size:
- todelete = list_root[PREV]
+ todelete = list_root.prev_node
delete_node(todelete)
- cache.pop(todelete[KEY], None)
+ cache.pop(todelete.key, None)
return value
@synchronized
@@ -119,8 +150,8 @@ class LruCache(object):
node = cache.get(key, None)
if node:
delete_node(node)
- cache.pop(node[KEY], None)
- return node[VALUE]
+ cache.pop(node.key, None)
+ return node.value
else:
return default
@@ -137,8 +168,11 @@ class LruCache(object):
@synchronized
def cache_clear():
- list_root[NEXT] = list_root
- list_root[PREV] = list_root
+ list_root.next_node = list_root
+ list_root.prev_node = list_root
+ for node in cache.values():
+ for cb in node.callbacks:
+ cb()
cache.clear()
@synchronized
|