diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 27ea65a0f6..6fa63f052e 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -55,10 +55,14 @@ cache_counter = metrics.register_cache(
class Cache(object):
- def __init__(self, name, max_entries=1000, keylen=1):
- self.cache = OrderedDict()
+ def __init__(self, name, max_entries=1000, keylen=1, lru=False):
+ if lru:
+ self.cache = LruCache(max_size=max_entries)
+ self.max_entries = None
+ else:
+ self.cache = OrderedDict()
+ self.max_entries = max_entries
- self.max_entries = max_entries
self.name = name
self.keylen = keylen
@@ -82,8 +86,9 @@ class Cache(object):
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
- while len(self.cache) > self.max_entries:
- self.cache.popitem(last=False)
+ if self.max_entries is not None:
+ while len(self.cache) >= self.max_entries:
+ self.cache.popitem(last=False)
self.cache[keyargs] = value
@@ -94,9 +99,7 @@ class Cache(object):
self.cache.pop(keyargs, None)
-# TODO(paul):
-# * consider other eviction strategies - LRU?
-def cached(max_entries=1000, num_args=1):
+def cached(max_entries=1000, num_args=1, lru=False):
""" A method decorator that applies a memoizing cache around the function.
The function is presumed to take zero or more arguments, which are used in
@@ -115,6 +118,7 @@ def cached(max_entries=1000, num_args=1):
name=orig.__name__,
max_entries=max_entries,
keylen=num_args,
+ lru=lru,
)
@functools.wraps(orig)
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index b6853ba2d4..96caf8c4c1 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -69,6 +69,28 @@ class CacheTestCase(unittest.TestCase):
cache.get(2)
cache.get(3)
+ def test_eviction_lru(self):
+ cache = Cache("test", max_entries=2, lru=True)
+
+ cache.prefill(1, "one")
+ cache.prefill(2, "two")
+
+ # Now access 1 again, thus causing 2 to be least-recently used
+ cache.get(1)
+
+ cache.prefill(3, "three")
+
+ failed = False
+ try:
+ cache.get(2)
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ cache.get(1)
+ cache.get(3)
+
class CacheDecoratorTestCase(unittest.TestCase):
|