summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/util/caches/lrucache.py2
-rw-r--r--synapse/util/caches/treecache.py167
2 files changed, 125 insertions, 44 deletions
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 6f95c1354e..80671daca3 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -417,7 +417,7 @@ class LruCache(Generic[KT, VT]):
         else:
             real_clock = clock
 
-        cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type()
+        cache: Union[Dict[KT, _Node[KT, VT]], TreeCache[_Node[KT, VT]]] = cache_type()
         self.cache = cache  # Used for introspection.
         self.apply_cache_factor_from_config = apply_cache_factor_from_config
 
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index c0136ea269..37dcc040f4 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -12,18 +12,59 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SENTINEL = object()
+from enum import Enum
+from typing import (
+    Any,
+    Dict,
+    Generator,
+    Generic,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+    overload,
+)
 
 
-class TreeCacheNode(dict):
+class Sentinel(Enum):
+    sentinel = object()
+
+
+V = TypeVar("V")
+T = TypeVar("T")
+
+
+class TreeCacheNode(Generic[V]):
     """The type of nodes in our tree.
 
-    Has its own type so we can distinguish it from real dicts that are stored at the
-    leaves.
+    Either a leaf node or a branch node.
     """
 
+    __slots__ = ["leaf_value", "sub_tree"]
+
+    def __init__(
+        self,
+        leaf_value: Union[V, Literal[Sentinel.sentinel]] = Sentinel.sentinel,
+        sub_tree: Optional[Dict[Any, "TreeCacheNode[V]"]] = None,
+    ) -> None:
+        if leaf_value is Sentinel.sentinel and sub_tree is None:
+            raise Exception("One of leaf or sub tree must be set")
+
+        self.leaf_value: Union[V, Literal[Sentinel.sentinel]] = leaf_value
+        self.sub_tree: Optional[Dict[Any, "TreeCacheNode[V]"]] = sub_tree
+
+    @staticmethod
+    def leaf(value: V) -> "TreeCacheNode[V]":
+        return TreeCacheNode(leaf_value=value)
+
+    @staticmethod
+    def empty_branch() -> "TreeCacheNode[V]":
+        return TreeCacheNode(sub_tree={})
 
-class TreeCache:
+
+class TreeCache(Generic[V]):
     """
     Tree-based backing store for LruCache. Allows subtrees of data to be deleted
     efficiently.
@@ -35,15 +76,15 @@ class TreeCache:
 
     def __init__(self) -> None:
         self.size: int = 0
-        self.root = TreeCacheNode()
+        self.root: TreeCacheNode[V] = TreeCacheNode.empty_branch()
 
-    def __setitem__(self, key, value) -> None:
+    def __setitem__(self, key: tuple, value: V) -> None:
         self.set(key, value)
 
-    def __contains__(self, key) -> bool:
-        return self.get(key, SENTINEL) is not SENTINEL
+    def __contains__(self, key: tuple) -> bool:
+        return self.get(key, None) is not None
 
-    def set(self, key, value) -> None:
+    def set(self, key: tuple, value: V) -> None:
         if isinstance(value, TreeCacheNode):
             # this would mean we couldn't tell where our tree ended and the value
             # started.
@@ -51,31 +92,56 @@ class TreeCache:
 
         node = self.root
         for k in key[:-1]:
-            next_node = node.get(k, SENTINEL)
-            if next_node is SENTINEL:
-                next_node = node[k] = TreeCacheNode()
-            elif not isinstance(next_node, TreeCacheNode):
-                # this suggests that the caller is not being consistent with its key
-                # length.
+            sub_tree = node.sub_tree
+            if sub_tree is None:
                 raise ValueError("value conflicts with an existing subtree")
-            node = next_node
 
-        node[key[-1]] = value
+            next_node = sub_tree.get(k, None)
+            if next_node is None:
+                node = TreeCacheNode.empty_branch()
+                sub_tree[k] = node
+            else:
+                node = next_node
+
+        if node.sub_tree is None:
+            raise ValueError("value conflicts with an existing subtree")
+
+        node.sub_tree[key[-1]] = TreeCacheNode.leaf(value)
         self.size += 1
 
-    def get(self, key, default=None):
+    @overload
+    def get(self, key: tuple, default: Literal[None] = None) -> Union[None, V]:
+        ...
+
+    @overload
+    def get(self, key: tuple, default: T) -> Union[T, V]:
+        ...
+
+    def get(self, key: tuple, default: Optional[T] = None) -> Union[None, T, V]:
         node = self.root
-        for k in key[:-1]:
-            node = node.get(k, None)
-            if node is None:
+        for k in key:
+            sub_tree = node.sub_tree
+            if sub_tree is None:
+                raise ValueError("get() key too long")
+
+            next_node = sub_tree.get(k, None)
+            if next_node is None:
                 return default
-        return node.get(key[-1], default)
+
+            node = next_node
+
+        if node.leaf_value is Sentinel.sentinel:
+            raise ValueError("key points to a branch")
+
+        return node.leaf_value
 
     def clear(self) -> None:
         self.size = 0
         self.root = TreeCacheNode()
 
-    def pop(self, key, default=None):
+    def pop(
+        self, key: tuple, default: Optional[T] = None
+    ) -> Union[None, T, V, TreeCacheNode[V]]:
         """Remove the given key, or subkey, from the cache
 
         Args:
@@ -91,20 +157,25 @@ class TreeCache:
             raise TypeError("The cache key must be a tuple not %r" % (type(key),))
 
         # a list of the nodes we have touched on the way down the tree
-        nodes = []
+        nodes: List[TreeCacheNode[V]] = []
 
         node = self.root
         for k in key[:-1]:
-            node = node.get(k, None)
-            if node is None:
-                return default
-            if not isinstance(node, TreeCacheNode):
-                # we've gone off the end of the tree
+            sub_tree = node.sub_tree
+            if sub_tree is None:
                 raise ValueError("pop() key too long")
-            nodes.append(node)  # don't add the root node
-        popped = node.pop(key[-1], SENTINEL)
-        if popped is SENTINEL:
-            return default
+
+            next_node = sub_tree.get(k, None)
+            if next_node is None:
+                return default
+
+            node = next_node
+            nodes.append(node)
+
+        if node.sub_tree is None:
+            raise ValueError("pop() key too long")
+
+        popped = node.sub_tree.pop(key[-1])
 
         # working back up the tree, clear out any nodes that are now empty
         node_and_keys = list(zip(nodes, key))
@@ -116,8 +187,13 @@ class TreeCache:
 
             if n:
                 break
+
             # found an empty node: remove it from its parent, and loop.
-            node_and_keys[i + 1][0].pop(k)
+            node = node_and_keys[i + 1][0]
+
+            # We added it to the list so already know its a branch node.
+            assert node.sub_tree is not None
+            node.sub_tree.pop(k)
 
         cnt = sum(1 for _ in iterate_tree_cache_entry(popped))
         self.size -= cnt
@@ -130,26 +206,31 @@ class TreeCache:
         return self.size
 
 
-def iterate_tree_cache_entry(d):
+def iterate_tree_cache_entry(d: TreeCacheNode[V]) -> Generator[V, None, None]:
     """Helper function to iterate over the leaves of a tree, i.e. a dict of that
     can contain dicts.
     """
-    if isinstance(d, TreeCacheNode):
-        for value_d in d.values():
+
+    if d.sub_tree is not None:
+        for value_d in d.sub_tree.values():
             yield from iterate_tree_cache_entry(value_d)
     else:
-        yield d
+        assert d.leaf_value is not Sentinel.sentinel
+        yield d.leaf_value
 
 
-def iterate_tree_cache_items(key, value):
+def iterate_tree_cache_items(
+    key: tuple, value: TreeCacheNode[V]
+) -> Generator[Tuple[tuple, V], None, None]:
     """Helper function to iterate over the leaves of a tree, i.e. a dict of that
     can contain dicts.
 
     Returns:
         A generator yielding key/value pairs.
     """
-    if isinstance(value, TreeCacheNode):
-        for sub_key, sub_value in value.items():
+    if value.sub_tree is not None:
+        for sub_key, sub_value in value.sub_tree.items():
             yield from iterate_tree_cache_items((*key, sub_key), sub_value)
     else:
-        yield key, value
+        assert value.leaf_value is not Sentinel.sentinel
+        yield key, value.leaf_value