diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 6f95c1354e..80671daca3 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -417,7 +417,7 @@ class LruCache(Generic[KT, VT]):
else:
real_clock = clock
- cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type()
+ cache: Union[Dict[KT, _Node[KT, VT]], TreeCache[_Node[KT, VT]]] = cache_type()
self.cache = cache # Used for introspection.
self.apply_cache_factor_from_config = apply_cache_factor_from_config
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index c0136ea269..37dcc040f4 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -12,18 +12,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-SENTINEL = object()
+from enum import Enum
+from typing import (
+ Any,
+ Dict,
+ Generator,
+ Generic,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ TypeVar,
+ Union,
+ overload,
+)
-class TreeCacheNode(dict):
+class Sentinel(Enum):
+ sentinel = object()
+
+
+V = TypeVar("V")
+T = TypeVar("T")
+
+
+class TreeCacheNode(Generic[V]):
"""The type of nodes in our tree.
- Has its own type so we can distinguish it from real dicts that are stored at the
- leaves.
+ Either a leaf node or a branch node.
"""
+ __slots__ = ["leaf_value", "sub_tree"]
+
+ def __init__(
+ self,
+ leaf_value: Union[V, Literal[Sentinel.sentinel]] = Sentinel.sentinel,
+ sub_tree: Optional[Dict[Any, "TreeCacheNode[V]"]] = None,
+ ) -> None:
+ if leaf_value is Sentinel.sentinel and sub_tree is None:
+ raise Exception("One of leaf or sub tree must be set")
+
+ self.leaf_value: Union[V, Literal[Sentinel.sentinel]] = leaf_value
+ self.sub_tree: Optional[Dict[Any, "TreeCacheNode[V]"]] = sub_tree
+
+ @staticmethod
+ def leaf(value: V) -> "TreeCacheNode[V]":
+ return TreeCacheNode(leaf_value=value)
+
+ @staticmethod
+ def empty_branch() -> "TreeCacheNode[V]":
+ return TreeCacheNode(sub_tree={})
-class TreeCache:
+
+class TreeCache(Generic[V]):
"""
Tree-based backing store for LruCache. Allows subtrees of data to be deleted
efficiently.
@@ -35,15 +76,15 @@ class TreeCache:
def __init__(self) -> None:
self.size: int = 0
- self.root = TreeCacheNode()
+ self.root: TreeCacheNode[V] = TreeCacheNode.empty_branch()
- def __setitem__(self, key, value) -> None:
+ def __setitem__(self, key: tuple, value: V) -> None:
self.set(key, value)
- def __contains__(self, key) -> bool:
- return self.get(key, SENTINEL) is not SENTINEL
+ def __contains__(self, key: tuple) -> bool:
+ return self.get(key, None) is not None
- def set(self, key, value) -> None:
+ def set(self, key: tuple, value: V) -> None:
if isinstance(value, TreeCacheNode):
# this would mean we couldn't tell where our tree ended and the value
# started.
@@ -51,31 +92,56 @@ class TreeCache:
node = self.root
for k in key[:-1]:
- next_node = node.get(k, SENTINEL)
- if next_node is SENTINEL:
- next_node = node[k] = TreeCacheNode()
- elif not isinstance(next_node, TreeCacheNode):
- # this suggests that the caller is not being consistent with its key
- # length.
+ sub_tree = node.sub_tree
+ if sub_tree is None:
raise ValueError("value conflicts with an existing subtree")
- node = next_node
- node[key[-1]] = value
+ next_node = sub_tree.get(k, None)
+ if next_node is None:
+ node = TreeCacheNode.empty_branch()
+ sub_tree[k] = node
+ else:
+ node = next_node
+
+ if node.sub_tree is None:
+ raise ValueError("value conflicts with an existing subtree")
+
+ node.sub_tree[key[-1]] = TreeCacheNode.leaf(value)
self.size += 1
- def get(self, key, default=None):
+ @overload
+ def get(self, key: tuple, default: Literal[None] = None) -> Union[None, V]:
+ ...
+
+ @overload
+ def get(self, key: tuple, default: T) -> Union[T, V]:
+ ...
+
+ def get(self, key: tuple, default: Optional[T] = None) -> Union[None, T, V]:
node = self.root
- for k in key[:-1]:
- node = node.get(k, None)
- if node is None:
+ for k in key:
+ sub_tree = node.sub_tree
+ if sub_tree is None:
+ raise ValueError("get() key too long")
+
+ next_node = sub_tree.get(k, None)
+ if next_node is None:
return default
- return node.get(key[-1], default)
+
+ node = next_node
+
+ if node.leaf_value is Sentinel.sentinel:
+ raise ValueError("key points to a branch")
+
+ return node.leaf_value
def clear(self) -> None:
self.size = 0
self.root = TreeCacheNode()
- def pop(self, key, default=None):
+ def pop(
+ self, key: tuple, default: Optional[T] = None
+ ) -> Union[None, T, V, TreeCacheNode[V]]:
"""Remove the given key, or subkey, from the cache
Args:
@@ -91,20 +157,25 @@ class TreeCache:
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
# a list of the nodes we have touched on the way down the tree
- nodes = []
+ nodes: List[TreeCacheNode[V]] = []
node = self.root
for k in key[:-1]:
- node = node.get(k, None)
- if node is None:
- return default
- if not isinstance(node, TreeCacheNode):
- # we've gone off the end of the tree
+ sub_tree = node.sub_tree
+ if sub_tree is None:
raise ValueError("pop() key too long")
- nodes.append(node) # don't add the root node
- popped = node.pop(key[-1], SENTINEL)
- if popped is SENTINEL:
- return default
+
+ next_node = sub_tree.get(k, None)
+ if next_node is None:
+ return default
+
+ node = next_node
+ nodes.append(node)
+
+ if node.sub_tree is None:
+ raise ValueError("pop() key too long")
+
+ popped = node.sub_tree.pop(key[-1])
# working back up the tree, clear out any nodes that are now empty
node_and_keys = list(zip(nodes, key))
@@ -116,8 +187,13 @@ class TreeCache:
if n:
break
+
# found an empty node: remove it from its parent, and loop.
- node_and_keys[i + 1][0].pop(k)
+ node = node_and_keys[i + 1][0]
+
+ # We added it to the list so already know its a branch node.
+ assert node.sub_tree is not None
+ node.sub_tree.pop(k)
cnt = sum(1 for _ in iterate_tree_cache_entry(popped))
self.size -= cnt
@@ -130,26 +206,31 @@ class TreeCache:
return self.size
-def iterate_tree_cache_entry(d):
+def iterate_tree_cache_entry(d: TreeCacheNode[V]) -> Generator[V, None, None]:
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
can contain dicts.
"""
- if isinstance(d, TreeCacheNode):
- for value_d in d.values():
+
+ if d.sub_tree is not None:
+ for value_d in d.sub_tree.values():
yield from iterate_tree_cache_entry(value_d)
else:
- yield d
+ assert d.leaf_value is not Sentinel.sentinel
+ yield d.leaf_value
-def iterate_tree_cache_items(key, value):
+def iterate_tree_cache_items(
+ key: tuple, value: TreeCacheNode[V]
+) -> Generator[Tuple[tuple, V], None, None]:
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
can contain dicts.
Returns:
A generator yielding key/value pairs.
"""
- if isinstance(value, TreeCacheNode):
- for sub_key, sub_value in value.items():
+ if value.sub_tree is not None:
+ for sub_key, sub_value in value.sub_tree.items():
yield from iterate_tree_cache_items((*key, sub_key), sub_value)
else:
- yield key, value
+ assert value.leaf_value is not Sentinel.sentinel
+ yield key, value.leaf_value
|