summary refs log tree commit diff
path: root/synapse/util/caches/treecache.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/treecache.py')
-rw-r--r--synapse/util/caches/treecache.py104
1 files changed, 66 insertions, 38 deletions
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index eb4d98f683..73502a8b06 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -1,18 +1,43 @@
-from typing import Dict
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 SENTINEL = object()
 
 
+class TreeCacheNode(dict):
+    """The type of nodes in our tree.
+
+    Has its own type so we can distinguish it from real dicts that are stored at the
+    leaves.
+    """
+
+    pass
+
+
 class TreeCache:
     """
     Tree-based backing store for LruCache. Allows subtrees of data to be deleted
     efficiently.
     Keys must be tuples.
+
+    The data structure is a chain of TreeCacheNodes:
+        root = {key_1: {key_2: _value}}
     """
 
     def __init__(self):
         self.size = 0
-        self.root = {}  # type: Dict
+        self.root = TreeCacheNode()
 
     def __setitem__(self, key, value):
         return self.set(key, value)
@@ -21,10 +46,23 @@ class TreeCache:
         return self.get(key, SENTINEL) is not SENTINEL
 
     def set(self, key, value):
+        if isinstance(value, TreeCacheNode):
+            # this would mean we couldn't tell where our tree ended and the value
+            # started.
+            raise ValueError("Cannot store TreeCacheNodes in a TreeCache")
+
         node = self.root
         for k in key[:-1]:
-            node = node.setdefault(k, {})
-        node[key[-1]] = _Entry(value)
+            next_node = node.get(k, SENTINEL)
+            if next_node is SENTINEL:
+                next_node = node[k] = TreeCacheNode()
+            elif not isinstance(next_node, TreeCacheNode):
+                # this suggests that the caller is not being consistent with its key
+                # length.
+                raise ValueError("value conflicts with an existing subtree")
+            node = next_node
+
+        node[key[-1]] = value
         self.size += 1
 
     def get(self, key, default=None):
@@ -33,25 +71,41 @@ class TreeCache:
             node = node.get(k, None)
             if node is None:
                 return default
-        return node.get(key[-1], _Entry(default)).value
+        return node.get(key[-1], default)
 
     def clear(self):
         self.size = 0
-        self.root = {}
+        self.root = TreeCacheNode()
 
     def pop(self, key, default=None):
+        """Remove the given key, or subkey, from the cache
+
+        Args:
+            key: key or subkey to remove.
+            default: value to return if key is not found
+
+        Returns:
+            If the key is not found, 'default'. If the key is complete, the removed
+            value. If the key is partial, the TreeCacheNode corresponding to the part
+            of the tree that was removed.
+        """
+        # a list of the nodes we have touched on the way down the tree
         nodes = []
 
         node = self.root
         for k in key[:-1]:
             node = node.get(k, None)
-            nodes.append(node)  # don't add the root node
             if node is None:
                 return default
+            if not isinstance(node, TreeCacheNode):
+                # we've gone off the end of the tree
+                raise ValueError("pop() key too long")
+            nodes.append(node)  # don't add the root node
         popped = node.pop(key[-1], SENTINEL)
         if popped is SENTINEL:
             return default
 
+        # working back up the tree, clear out any nodes that are now empty
         node_and_keys = list(zip(nodes, key))
         node_and_keys.reverse()
         node_and_keys.append((self.root, None))
@@ -61,14 +115,15 @@ class TreeCache:
 
             if n:
                 break
+            # found an empty node: remove it from its parent, and loop.
             node_and_keys[i + 1][0].pop(k)
 
-        popped, cnt = _strip_and_count_entires(popped)
+        cnt = sum(1 for _ in iterate_tree_cache_entry(popped))
         self.size -= cnt
         return popped
 
     def values(self):
-        return list(iterate_tree_cache_entry(self.root))
+        return iterate_tree_cache_entry(self.root)
 
     def __len__(self):
         return self.size
@@ -78,36 +133,9 @@ def iterate_tree_cache_entry(d):
     """Helper function to iterate over the leaves of a tree, i.e. a dict of that
     can contain dicts.
     """
-    if isinstance(d, dict):
+    if isinstance(d, TreeCacheNode):
         for value_d in d.values():
             for value in iterate_tree_cache_entry(value_d):
                 yield value
     else:
-        if isinstance(d, _Entry):
-            yield d.value
-        else:
-            yield d
-
-
-class _Entry:
-    __slots__ = ["value"]
-
-    def __init__(self, value):
-        self.value = value
-
-
-def _strip_and_count_entires(d):
-    """Takes an _Entry or dict with leaves of _Entry's, and either returns the
-    value or a dictionary with _Entry's replaced by their values.
-
-    Also returns the count of _Entry's
-    """
-    if isinstance(d, dict):
-        cnt = 0
-        for key, value in d.items():
-            v, n = _strip_and_count_entires(value)
-            d[key] = v
-            cnt += n
-        return d, cnt
-    else:
-        return d.value, 1
+        yield d