diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index ab6095564a..8361dd8cee 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
callcount2 = [0]
class A(object):
- @cached(max_entries=2)
+ @cached(max_entries=20) # HACK: This makes it 2 due to cache factor
def func(self, key):
callcount[0] += 1
return key
@@ -258,6 +258,10 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
+ yield a.func2("foo")
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
yield a.func("foo3")
self.assertEquals(callcount[0], 3)
diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
new file mode 100644
index 0000000000..31d24adb8b
--- /dev/null
+++ b/tests/util/test_expiring_cache.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .. import unittest
+
+from synapse.util.caches.expiringcache import ExpiringCache
+
+from tests.utils import MockClock
+
+
+class ExpiringCacheTestCase(unittest.TestCase):
+
+ def test_get_set(self):
+ clock = MockClock()
+ cache = ExpiringCache("test", clock, max_len=1)
+
+ cache["key"] = "value"
+ self.assertEquals(cache.get("key"), "value")
+ self.assertEquals(cache["key"], "value")
+
+ def test_eviction(self):
+ clock = MockClock()
+ cache = ExpiringCache("test", clock, max_len=2)
+
+ cache["key"] = "value"
+ cache["key2"] = "value2"
+ self.assertEquals(cache.get("key"), "value")
+ self.assertEquals(cache.get("key2"), "value2")
+
+ cache["key3"] = "value3"
+ self.assertEquals(cache.get("key"), None)
+ self.assertEquals(cache.get("key2"), "value2")
+ self.assertEquals(cache.get("key3"), "value3")
+
+ def test_iterable_eviction(self):
+ clock = MockClock()
+ cache = ExpiringCache("test", clock, max_len=5, iterable=True)
+
+ cache["key"] = [1]
+ cache["key2"] = [2, 3]
+ cache["key3"] = [4, 5]
+
+ self.assertEquals(cache.get("key"), [1])
+ self.assertEquals(cache.get("key2"), [2, 3])
+ self.assertEquals(cache.get("key3"), [4, 5])
+
+ cache["key4"] = [6, 7]
+ self.assertEquals(cache.get("key"), None)
+ self.assertEquals(cache.get("key2"), None)
+ self.assertEquals(cache.get("key3"), [4, 5])
+ self.assertEquals(cache.get("key4"), [6, 7])
+
+ def test_time_eviction(self):
+ clock = MockClock()
+ cache = ExpiringCache("test", clock, expiry_ms=1000)
+ cache.start()
+
+ cache["key"] = 1
+ clock.advance_time(0.5)
+ cache["key2"] = 2
+
+ self.assertEquals(cache.get("key"), 1)
+ self.assertEquals(cache.get("key2"), 2)
+
+ clock.advance_time(0.9)
+ self.assertEquals(cache.get("key"), None)
+ self.assertEquals(cache.get("key2"), 2)
+
+ clock.advance_time(1)
+ self.assertEquals(cache.get("key"), None)
+ self.assertEquals(cache.get("key2"), None)
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 1eba5b535e..dfb78cb8bd 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -93,7 +93,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
cache.set("key", "value")
self.assertFalse(m.called)
- cache.get("key", callback=m)
+ cache.get("key", callbacks=[m])
self.assertFalse(m.called)
cache.get("key", "value")
@@ -112,10 +112,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
cache.set("key", "value")
self.assertFalse(m.called)
- cache.get("key", callback=m)
+ cache.get("key", callbacks=[m])
self.assertFalse(m.called)
- cache.get("key", callback=m)
+ cache.get("key", callbacks=[m])
self.assertFalse(m.called)
cache.set("key", "value2")
@@ -128,7 +128,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m = Mock()
cache = LruCache(1)
- cache.set("key", "value", m)
+ cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
cache.set("key", "value")
@@ -144,7 +144,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m = Mock()
cache = LruCache(1)
- cache.set("key", "value", m)
+ cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
cache.pop("key")
@@ -163,10 +163,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m4 = Mock()
cache = LruCache(4, 2, cache_type=TreeCache)
- cache.set(("a", "1"), "value", m1)
- cache.set(("a", "2"), "value", m2)
- cache.set(("b", "1"), "value", m3)
- cache.set(("b", "2"), "value", m4)
+ cache.set(("a", "1"), "value", callbacks=[m1])
+ cache.set(("a", "2"), "value", callbacks=[m2])
+ cache.set(("b", "1"), "value", callbacks=[m3])
+ cache.set(("b", "2"), "value", callbacks=[m4])
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
@@ -185,8 +185,8 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m2 = Mock()
cache = LruCache(5)
- cache.set("key1", "value", m1)
- cache.set("key2", "value", m2)
+ cache.set("key1", "value", callbacks=[m1])
+ cache.set("key2", "value", callbacks=[m2])
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
@@ -202,14 +202,14 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m3 = Mock(name="m3")
cache = LruCache(2)
- cache.set("key1", "value", m1)
- cache.set("key2", "value", m2)
+ cache.set("key1", "value", callbacks=[m1])
+ cache.set("key2", "value", callbacks=[m2])
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
- cache.set("key3", "value", m3)
+ cache.set("key3", "value", callbacks=[m3])
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
@@ -227,8 +227,33 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
- cache.set("key1", "value", m1)
+ cache.set("key1", "value", callbacks=[m1])
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 1)
+
+
+class LruCacheSizedTestCase(unittest.TestCase):
+
+ def test_evict(self):
+ cache = LruCache(5, size_callback=len)
+ cache["key1"] = [0]
+ cache["key2"] = [1, 2]
+ cache["key3"] = [3]
+ cache["key4"] = [4]
+
+ self.assertEquals(cache["key1"], [0])
+ self.assertEquals(cache["key2"], [1, 2])
+ self.assertEquals(cache["key3"], [3])
+ self.assertEquals(cache["key4"], [4])
+ self.assertEquals(len(cache), 5)
+
+ cache["key5"] = [5, 6]
+
+ self.assertEquals(len(cache), 4)
+ self.assertEquals(cache.get("key1"), None)
+ self.assertEquals(cache.get("key2"), None)
+ self.assertEquals(cache["key3"], [3])
+ self.assertEquals(cache["key4"], [4])
+ self.assertEquals(cache["key5"], [5, 6])
|