summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2016-01-29 10:44:46 +0000
committerErik Johnston <erik@matrix.org>2016-01-29 10:44:46 +0000
commita30364c1f99bdd7d5cb0fe82ebdfe52d996defef (patch)
tree0d292385a4d58f43ece0f7a8f54ac005f0e7aa6b /synapse
parentMake TreeCache keep track of its own size. (diff)
downloadsynapse-a30364c1f99bdd7d5cb0fe82ebdfe52d996defef.tar.xz
Correctly bookkeep the size of TreeCache
Diffstat (limited to '')
-rw-r--r--synapse/util/caches/treecache.py31
1 files changed, 28 insertions, 3 deletions
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index a29ea8144e..3331ea9eba 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -21,7 +21,7 @@ class TreeCache(object):
         node = self.root
         for k in key[:-1]:
             node = node.setdefault(k, {})
-        node[key[-1]] = value
+        node[key[-1]] = _Entry(value)
         self.size += 1
 
     def get(self, key, default=None):
@@ -30,7 +30,7 @@ class TreeCache(object):
             node = node.get(k, None)
             if node is None:
                 return default
-        return node.get(key[-1], default)
+        return node.get(key[-1], _Entry(default)).value
 
     def clear(self):
         self.size = 0
@@ -60,8 +60,33 @@ class TreeCache(object):
                 break
             node_and_keys[i+1][0].pop(k)
 
-        self.size -= 1
+        popped, cnt = _strip_and_count_entires(popped)
+        self.size -= cnt
         return popped
 
     def __len__(self):
         return self.size
+
+
+class _Entry(object):
+    __slots__ = ["value"]
+
+    def __init__(self, value):
+        object.__setattr__(self, "value", value)
+
+
+def _strip_and_count_entires(d):
+    """Takes an _Entry or dict with leaves of _Entry's, and either returns the
+    value or a dictionary with _Entry's replaced by their values.
+
+    Also returns the count of _Entry's
+    """
+    if isinstance(d, dict):
+        cnt = 0
+        for key, value in d.items():
+            v, n = _strip_and_count_entires(value)
+            d[key] = v
+            cnt += n
+        return d, cnt
+    else:
+        return d.value, 1