summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/util/caches/lrucache.py10
-rw-r--r--synapse/util/caches/treecache.py36
-rw-r--r--tests/util/test_lrucache.py7
-rw-r--r--tests/util/test_treecache.py12
4 files changed, 57 insertions, 8 deletions
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index e6a66dc041..f7423f2fab 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -37,7 +37,7 @@ class LruCache(object):
     """
     def __init__(self, max_size, keylen=1, cache_type=dict):
         cache = cache_type()
-        self.size = 0
+        self.cache = cache  # Used for introspection.
         list_root = []
         list_root[:] = [list_root, list_root, None, None]
 
@@ -60,7 +60,6 @@ class LruCache(object):
             prev_node[NEXT] = node
             next_node[PREV] = node
             cache[key] = node
-            self.size += 1
 
         def move_node_to_front(node):
             prev_node = node[PREV]
@@ -79,7 +78,6 @@ class LruCache(object):
             next_node = node[NEXT]
             prev_node[NEXT] = next_node
             next_node[PREV] = prev_node
-            self.size -= 1
 
         @synchronized
         def cache_get(key, default=None):
@@ -98,7 +96,7 @@ class LruCache(object):
                 node[VALUE] = value
             else:
                 add_node(key, value)
-                if self.size > max_size:
+                if len(cache) > max_size:
                     todelete = list_root[PREV]
                     delete_node(todelete)
                     cache.pop(todelete[KEY], None)
@@ -110,7 +108,7 @@ class LruCache(object):
                 return node[VALUE]
             else:
                 add_node(key, value)
-                if self.size > max_size:
+                if len(cache) > max_size:
                     todelete = list_root[PREV]
                     delete_node(todelete)
                     cache.pop(todelete[KEY], None)
@@ -145,7 +143,7 @@ class LruCache(object):
 
         @synchronized
         def cache_len():
-            return self.size
+            return len(cache)
 
         @synchronized
         def cache_contains(key):
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 3b58860910..29d02f7e95 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -8,6 +8,7 @@ class TreeCache(object):
     Keys must be tuples.
     """
     def __init__(self):
+        self.size = 0
         self.root = {}
 
     def __setitem__(self, key, value):
@@ -20,7 +21,8 @@ class TreeCache(object):
         node = self.root
         for k in key[:-1]:
             node = node.setdefault(k, {})
-        node[key[-1]] = value
+        node[key[-1]] = _Entry(value)
+        self.size += 1
 
     def get(self, key, default=None):
         node = self.root
@@ -28,9 +30,10 @@ class TreeCache(object):
             node = node.get(k, None)
             if node is None:
                 return default
-        return node.get(key[-1], default)
+        return node.get(key[-1], _Entry(default)).value
 
     def clear(self):
+        self.size = 0
         self.root = {}
 
     def pop(self, key, default=None):
@@ -57,4 +60,33 @@ class TreeCache(object):
                 break
             node_and_keys[i+1][0].pop(k)
 
+        popped, cnt = _strip_and_count_entires(popped)
+        self.size -= cnt
         return popped
+
+    def __len__(self):
+        return self.size
+
+
+class _Entry(object):
+    __slots__ = ["value"]
+
+    def __init__(self, value):
+        self.value = value
+
+
+def _strip_and_count_entires(d):
+    """Takes an _Entry or dict with leaves of _Entry's, and either returns the
+    value or a dictionary with _Entry's replaced by their values.
+
+    Also returns the count of _Entry's
+    """
+    if isinstance(d, dict):
+        cnt = 0
+        for key, value in d.items():
+            v, n = _strip_and_count_entires(value)
+            d[key] = v
+            cnt += n
+        return d, cnt
+    else:
+        return d.value, 1
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 2cd3d26454..bab366fb7f 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -19,6 +19,7 @@ from .. import unittest
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache
 
+
 class LruCacheTestCase(unittest.TestCase):
 
     def test_get_set(self):
@@ -72,3 +73,9 @@ class LruCacheTestCase(unittest.TestCase):
         self.assertEquals(cache.get(("vehicles", "car")), "vroom")
         self.assertEquals(cache.get(("vehicles", "train")), "chuff")
         # Man from del_multi say "Yes".
+
+    def test_clear(self):
+        cache = LruCache(1)
+        cache["key"] = 1
+        cache.clear()
+        self.assertEquals(len(cache), 0)
diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py
index 9946ceb3f1..1efbeb6b33 100644
--- a/tests/util/test_treecache.py
+++ b/tests/util/test_treecache.py
@@ -25,6 +25,7 @@ class TreeCacheTestCase(unittest.TestCase):
         cache[("b",)] = "B"
         self.assertEquals(cache.get(("a",)), "A")
         self.assertEquals(cache.get(("b",)), "B")
+        self.assertEquals(len(cache), 2)
 
     def test_pop_onelevel(self):
         cache = TreeCache()
@@ -33,6 +34,7 @@ class TreeCacheTestCase(unittest.TestCase):
         self.assertEquals(cache.pop(("a",)), "A")
         self.assertEquals(cache.pop(("a",)), None)
         self.assertEquals(cache.get(("b",)), "B")
+        self.assertEquals(len(cache), 1)
 
     def test_get_set_twolevel(self):
         cache = TreeCache()
@@ -42,6 +44,7 @@ class TreeCacheTestCase(unittest.TestCase):
         self.assertEquals(cache.get(("a", "a")), "AA")
         self.assertEquals(cache.get(("a", "b")), "AB")
         self.assertEquals(cache.get(("b", "a")), "BA")
+        self.assertEquals(len(cache), 3)
 
     def test_pop_twolevel(self):
         cache = TreeCache()
@@ -53,6 +56,7 @@ class TreeCacheTestCase(unittest.TestCase):
         self.assertEquals(cache.get(("a", "b")), "AB")
         self.assertEquals(cache.pop(("b", "a")), "BA")
         self.assertEquals(cache.pop(("b", "a")), None)
+        self.assertEquals(len(cache), 1)
 
     def test_pop_mixedlevel(self):
         cache = TreeCache()
@@ -64,3 +68,11 @@ class TreeCacheTestCase(unittest.TestCase):
         self.assertEquals(cache.get(("a", "a")), None)
         self.assertEquals(cache.get(("a", "b")), None)
         self.assertEquals(cache.get(("b", "a")), "BA")
+        self.assertEquals(len(cache), 1)
+
+    def test_clear(self):
+        cache = TreeCache()
+        cache[("a",)] = "A"
+        cache[("b",)] = "B"
+        cache.clear()
+        self.assertEquals(len(cache), 0)