summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/caches/deferred_cache.py2
-rw-r--r--synapse/util/caches/descriptors.py1
-rw-r--r--synapse/util/caches/lrucache.py10
-rw-r--r--synapse/util/caches/treecache.py104
4 files changed, 70 insertions, 47 deletions
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 484097a48a..371e7e4dd0 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -70,7 +70,6 @@ class DeferredCache(Generic[KT, VT]):
         self,
         name: str,
         max_entries: int = 1000,
-        keylen: int = 1,
         tree: bool = False,
         iterable: bool = False,
         apply_cache_factor_from_config: bool = True,
@@ -101,7 +100,6 @@ class DeferredCache(Generic[KT, VT]):
         # a Deferred.
         self.cache = LruCache(
             max_size=max_entries,
-            keylen=keylen,
             cache_name=name,
             cache_type=cache_type,
             size_callback=(lambda d: len(d) or 1) if iterable else None,
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 3a4d027095..2ac24a2f25 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -270,7 +270,6 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         cache = DeferredCache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
-            keylen=self.num_args,
             tree=self.tree,
             iterable=self.iterable,
         )  # type: DeferredCache[CacheKey, Any]
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 1be675e014..54df407ff7 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -34,7 +34,7 @@ from typing_extensions import Literal
 from synapse.config import cache as cache_config
 from synapse.util import caches
 from synapse.util.caches import CacheMetric, register_cache
-from synapse.util.caches.treecache import TreeCache
+from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
 
 try:
     from pympler.asizeof import Asizer
@@ -160,7 +160,6 @@ class LruCache(Generic[KT, VT]):
         self,
         max_size: int,
         cache_name: Optional[str] = None,
-        keylen: int = 1,
         cache_type: Type[Union[dict, TreeCache]] = dict,
         size_callback: Optional[Callable] = None,
         metrics_collection_callback: Optional[Callable[[], None]] = None,
@@ -173,9 +172,6 @@ class LruCache(Generic[KT, VT]):
             cache_name: The name of this cache, for the prometheus metrics. If unset,
                 no metrics will be reported on this cache.
 
-            keylen: The length of the tuple used as the cache key. Ignored unless
-                cache_type is `TreeCache`.
-
             cache_type (type):
                 type of underlying cache to be used. Typically one of dict
                 or TreeCache.
@@ -403,7 +399,9 @@ class LruCache(Generic[KT, VT]):
             popped = cache.pop(key)
             if popped is None:
                 return
-            for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
+            # for each deleted node, we now need to remove it from the linked list
+            # and run its callbacks.
+            for leaf in iterate_tree_cache_entry(popped):
                 delete_node(leaf)
 
         @synchronized
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index eb4d98f683..73502a8b06 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -1,18 +1,43 @@
-from typing import Dict
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 SENTINEL = object()
 
 
+class TreeCacheNode(dict):
+    """The type of nodes in our tree.
+
+    Has its own type so we can distinguish it from real dicts that are stored at the
+    leaves.
+    """
+
+    pass
+
+
 class TreeCache:
     """
     Tree-based backing store for LruCache. Allows subtrees of data to be deleted
     efficiently.
     Keys must be tuples.
+
+    The data structure is a chain of TreeCacheNodes:
+        root = {key_1: {key_2: _value}}
     """
 
     def __init__(self):
         self.size = 0
-        self.root = {}  # type: Dict
+        self.root = TreeCacheNode()
 
     def __setitem__(self, key, value):
         return self.set(key, value)
@@ -21,10 +46,23 @@ class TreeCache:
         return self.get(key, SENTINEL) is not SENTINEL
 
     def set(self, key, value):
+        if isinstance(value, TreeCacheNode):
+            # this would mean we couldn't tell where our tree ended and the value
+            # started.
+            raise ValueError("Cannot store TreeCacheNodes in a TreeCache")
+
         node = self.root
         for k in key[:-1]:
-            node = node.setdefault(k, {})
-        node[key[-1]] = _Entry(value)
+            next_node = node.get(k, SENTINEL)
+            if next_node is SENTINEL:
+                next_node = node[k] = TreeCacheNode()
+            elif not isinstance(next_node, TreeCacheNode):
+                # this suggests that the caller is not being consistent with its key
+                # length.
+                raise ValueError("value conflicts with an existing subtree")
+            node = next_node
+
+        node[key[-1]] = value
         self.size += 1
 
     def get(self, key, default=None):
@@ -33,25 +71,41 @@ class TreeCache:
             node = node.get(k, None)
             if node is None:
                 return default
-        return node.get(key[-1], _Entry(default)).value
+        return node.get(key[-1], default)
 
     def clear(self):
         self.size = 0
-        self.root = {}
+        self.root = TreeCacheNode()
 
     def pop(self, key, default=None):
+        """Remove the given key, or subkey, from the cache
+
+        Args:
+            key: key or subkey to remove.
+            default: value to return if key is not found
+
+        Returns:
+            If the key is not found, 'default'. If the key is complete, the removed
+            value. If the key is partial, the TreeCacheNode corresponding to the part
+            of the tree that was removed.
+        """
+        # a list of the nodes we have touched on the way down the tree
         nodes = []
 
         node = self.root
         for k in key[:-1]:
             node = node.get(k, None)
-            nodes.append(node)  # don't add the root node
             if node is None:
                 return default
+            if not isinstance(node, TreeCacheNode):
+                # we've gone off the end of the tree
+                raise ValueError("pop() key too long")
+            nodes.append(node)  # don't add the root node
         popped = node.pop(key[-1], SENTINEL)
         if popped is SENTINEL:
             return default
 
+        # working back up the tree, clear out any nodes that are now empty
         node_and_keys = list(zip(nodes, key))
         node_and_keys.reverse()
         node_and_keys.append((self.root, None))
@@ -61,14 +115,15 @@ class TreeCache:
 
             if n:
                 break
+            # found an empty node: remove it from its parent, and loop.
             node_and_keys[i + 1][0].pop(k)
 
-        popped, cnt = _strip_and_count_entires(popped)
+        cnt = sum(1 for _ in iterate_tree_cache_entry(popped))
         self.size -= cnt
         return popped
 
     def values(self):
-        return list(iterate_tree_cache_entry(self.root))
+        return iterate_tree_cache_entry(self.root)
 
     def __len__(self):
         return self.size
@@ -78,36 +133,9 @@ def iterate_tree_cache_entry(d):
     """Helper function to iterate over the leaves of a tree, i.e. a dict of that
     can contain dicts.
     """
-    if isinstance(d, dict):
+    if isinstance(d, TreeCacheNode):
         for value_d in d.values():
             for value in iterate_tree_cache_entry(value_d):
                 yield value
     else:
-        if isinstance(d, _Entry):
-            yield d.value
-        else:
-            yield d
-
-
-class _Entry:
-    __slots__ = ["value"]
-
-    def __init__(self, value):
-        self.value = value
-
-
-def _strip_and_count_entires(d):
-    """Takes an _Entry or dict with leaves of _Entry's, and either returns the
-    value or a dictionary with _Entry's replaced by their values.
-
-    Also returns the count of _Entry's
-    """
-    if isinstance(d, dict):
-        cnt = 0
-        for key, value in d.items():
-            v, n = _strip_and_count_entires(value)
-            d[key] = v
-            cnt += n
-        return d, cnt
-    else:
-        return d.value, 1
+        yield d