diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index cacd7e45fa..f7423f2fab 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
+# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,11 +17,27 @@
from functools import wraps
import threading
+from synapse.util.caches.treecache import TreeCache
+
+
+def enumerate_leaves(node, depth):
+ if depth == 0:
+ yield node
+ else:
+ for n in node.values():
+ for m in enumerate_leaves(n, depth - 1):
+ yield m
+
class LruCache(object):
- """Least-recently-used cache."""
- def __init__(self, max_size):
- cache = {}
+ """
+ Least-recently-used cache.
+ Supports del_multi only if cache_type=TreeCache
+ If cache_type=TreeCache, all keys must be tuples.
+ """
+ def __init__(self, max_size, keylen=1, cache_type=dict):
+ cache = cache_type()
+ self.cache = cache # Used for introspection.
list_root = []
list_root[:] = [list_root, list_root, None, None]
@@ -62,7 +78,6 @@ class LruCache(object):
next_node = node[NEXT]
prev_node[NEXT] = next_node
next_node[PREV] = prev_node
- cache.pop(node[KEY], None)
@synchronized
def cache_get(key, default=None):
@@ -82,7 +97,9 @@ class LruCache(object):
else:
add_node(key, value)
if len(cache) > max_size:
- delete_node(list_root[PREV])
+ todelete = list_root[PREV]
+ delete_node(todelete)
+ cache.pop(todelete[KEY], None)
@synchronized
def cache_set_default(key, value):
@@ -92,7 +109,9 @@ class LruCache(object):
else:
add_node(key, value)
if len(cache) > max_size:
- delete_node(list_root[PREV])
+ todelete = list_root[PREV]
+ delete_node(todelete)
+ cache.pop(todelete[KEY], None)
return value
@synchronized
@@ -100,11 +119,23 @@ class LruCache(object):
node = cache.get(key, None)
if node:
delete_node(node)
+ cache.pop(node[KEY], None)
return node[VALUE]
else:
return default
@synchronized
+ def cache_del_multi(key):
+ """
+ This will only work if constructed with cache_type=TreeCache
+ """
+ popped = cache.pop(key)
+ if popped is None:
+ return
+ for leaf in enumerate_leaves(popped, keylen - len(key)):
+ delete_node(leaf)
+
+ @synchronized
def cache_clear():
list_root[NEXT] = list_root
list_root[PREV] = list_root
@@ -123,6 +154,8 @@ class LruCache(object):
self.set = cache_set
self.setdefault = cache_set_default
self.pop = cache_pop
+ if cache_type is TreeCache:
+ self.del_multi = cache_del_multi
self.len = cache_len
self.contains = cache_contains
self.clear = cache_clear
|