diff options
-rw-r--r-- | synapse/util/caches/lrucache.py | 2 | ||||
-rw-r--r-- | synapse/util/caches/treecache.py | 167 |
2 files changed, 125 insertions, 44 deletions
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 6f95c1354e..80671daca3 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -417,7 +417,7 @@ class LruCache(Generic[KT, VT]): else: real_clock = clock - cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type() + cache: Union[Dict[KT, _Node[KT, VT]], TreeCache[_Node[KT, VT]]] = cache_type() self.cache = cache # Used for introspection. self.apply_cache_factor_from_config = apply_cache_factor_from_config diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index c0136ea269..37dcc040f4 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -12,18 +12,59 @@ # See the License for the specific language governing permissions and # limitations under the License. -SENTINEL = object() +from enum import Enum +from typing import ( + Any, + Dict, + Generator, + Generic, + List, + Literal, + Optional, + Tuple, + TypeVar, + Union, + overload, +) -class TreeCacheNode(dict): +class Sentinel(Enum): + sentinel = object() + + +V = TypeVar("V") +T = TypeVar("T") + + +class TreeCacheNode(Generic[V]): """The type of nodes in our tree. - Has its own type so we can distinguish it from real dicts that are stored at the - leaves. + Either a leaf node or a branch node. """ + __slots__ = ["leaf_value", "sub_tree"] + + def __init__( + self, + leaf_value: Union[V, Literal[Sentinel.sentinel]] = Sentinel.sentinel, + sub_tree: Optional[Dict[Any, "TreeCacheNode[V]"]] = None, + ) -> None: + if leaf_value is Sentinel.sentinel and sub_tree is None: + raise Exception("One of leaf or sub tree must be set") + + self.leaf_value: Union[V, Literal[Sentinel.sentinel]] = leaf_value + self.sub_tree: Optional[Dict[Any, "TreeCacheNode[V]"]] = sub_tree + + @staticmethod + def leaf(value: V) -> "TreeCacheNode[V]": + return TreeCacheNode(leaf_value=value) + + @staticmethod + def empty_branch() -> "TreeCacheNode[V]": + return TreeCacheNode(sub_tree={}) -class TreeCache: + +class TreeCache(Generic[V]): """ Tree-based backing store for LruCache. Allows subtrees of data to be deleted efficiently. @@ -35,15 +76,15 @@ class TreeCache: def __init__(self) -> None: self.size: int = 0 - self.root = TreeCacheNode() + self.root: TreeCacheNode[V] = TreeCacheNode.empty_branch() - def __setitem__(self, key, value) -> None: + def __setitem__(self, key: tuple, value: V) -> None: self.set(key, value) - def __contains__(self, key) -> bool: - return self.get(key, SENTINEL) is not SENTINEL + def __contains__(self, key: tuple) -> bool: + return self.get(key, None) is not None - def set(self, key, value) -> None: + def set(self, key: tuple, value: V) -> None: if isinstance(value, TreeCacheNode): # this would mean we couldn't tell where our tree ended and the value # started. @@ -51,31 +92,56 @@ class TreeCache: node = self.root for k in key[:-1]: - next_node = node.get(k, SENTINEL) - if next_node is SENTINEL: - next_node = node[k] = TreeCacheNode() - elif not isinstance(next_node, TreeCacheNode): - # this suggests that the caller is not being consistent with its key - # length. + sub_tree = node.sub_tree + if sub_tree is None: raise ValueError("value conflicts with an existing subtree") - node = next_node - node[key[-1]] = value + next_node = sub_tree.get(k, None) + if next_node is None: + node = TreeCacheNode.empty_branch() + sub_tree[k] = node + else: + node = next_node + + if node.sub_tree is None: + raise ValueError("value conflicts with an existing subtree") + + node.sub_tree[key[-1]] = TreeCacheNode.leaf(value) self.size += 1 - def get(self, key, default=None): + @overload + def get(self, key: tuple, default: Literal[None] = None) -> Union[None, V]: + ... + + @overload + def get(self, key: tuple, default: T) -> Union[T, V]: + ... + + def get(self, key: tuple, default: Optional[T] = None) -> Union[None, T, V]: node = self.root - for k in key[:-1]: - node = node.get(k, None) - if node is None: + for k in key: + sub_tree = node.sub_tree + if sub_tree is None: + raise ValueError("get() key too long") + + next_node = sub_tree.get(k, None) + if next_node is None: return default - return node.get(key[-1], default) + + node = next_node + + if node.leaf_value is Sentinel.sentinel: + raise ValueError("key points to a branch") + + return node.leaf_value def clear(self) -> None: self.size = 0 self.root = TreeCacheNode() - def pop(self, key, default=None): + def pop( + self, key: tuple, default: Optional[T] = None + ) -> Union[None, T, V, TreeCacheNode[V]]: """Remove the given key, or subkey, from the cache Args: @@ -91,20 +157,25 @@ class TreeCache: raise TypeError("The cache key must be a tuple not %r" % (type(key),)) # a list of the nodes we have touched on the way down the tree - nodes = [] + nodes: List[TreeCacheNode[V]] = [] node = self.root for k in key[:-1]: - node = node.get(k, None) - if node is None: - return default - if not isinstance(node, TreeCacheNode): - # we've gone off the end of the tree + sub_tree = node.sub_tree + if sub_tree is None: raise ValueError("pop() key too long") - nodes.append(node) # don't add the root node - popped = node.pop(key[-1], SENTINEL) - if popped is SENTINEL: - return default + + next_node = sub_tree.get(k, None) + if next_node is None: + return default + + node = next_node + nodes.append(node) + + if node.sub_tree is None: + raise ValueError("pop() key too long") + + popped = node.sub_tree.pop(key[-1]) # working back up the tree, clear out any nodes that are now empty node_and_keys = list(zip(nodes, key)) @@ -116,8 +187,13 @@ class TreeCache: if n: break + # found an empty node: remove it from its parent, and loop. - node_and_keys[i + 1][0].pop(k) + node = node_and_keys[i + 1][0] + + # We added it to the list so already know its a branch node. + assert node.sub_tree is not None + node.sub_tree.pop(k) cnt = sum(1 for _ in iterate_tree_cache_entry(popped)) self.size -= cnt @@ -130,26 +206,31 @@ class TreeCache: return self.size -def iterate_tree_cache_entry(d): +def iterate_tree_cache_entry(d: TreeCacheNode[V]) -> Generator[V, None, None]: """Helper function to iterate over the leaves of a tree, i.e. a dict of that can contain dicts. """ - if isinstance(d, TreeCacheNode): - for value_d in d.values(): + + if d.sub_tree is not None: + for value_d in d.sub_tree.values(): yield from iterate_tree_cache_entry(value_d) else: - yield d + assert d.leaf_value is not Sentinel.sentinel + yield d.leaf_value -def iterate_tree_cache_items(key, value): +def iterate_tree_cache_items( + key: tuple, value: TreeCacheNode[V] +) -> Generator[Tuple[tuple, V], None, None]: """Helper function to iterate over the leaves of a tree, i.e. a dict of that can contain dicts. Returns: A generator yielding key/value pairs. """ - if isinstance(value, TreeCacheNode): - for sub_key, sub_value in value.items(): + if value.sub_tree is not None: + for sub_key, sub_value in value.sub_tree.items(): yield from iterate_tree_cache_items((*key, sub_key), sub_value) else: - yield key, value + assert value.leaf_value is not Sentinel.sentinel + yield key, value.leaf_value |