summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/caches/deferred_cache.py42
-rw-r--r--synapse/util/caches/descriptors.py8
-rw-r--r--synapse/util/caches/lrucache.py18
-rw-r--r--synapse/util/caches/treecache.py3
4 files changed, 36 insertions, 35 deletions
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 371e7e4dd0..1044139119 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -16,16 +16,7 @@
 
 import enum
 import threading
-from typing import (
-    Callable,
-    Generic,
-    Iterable,
-    MutableMapping,
-    Optional,
-    TypeVar,
-    Union,
-    cast,
-)
+from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union
 
 from prometheus_client import Gauge
 
@@ -91,7 +82,7 @@ class DeferredCache(Generic[KT, VT]):
         # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
         self._pending_deferred_cache = (
             cache_type()
-        )  # type: MutableMapping[KT, CacheEntry]
+        )  # type: Union[TreeCache, MutableMapping[KT, CacheEntry]]
 
         def metrics_cb():
             cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
@@ -287,8 +278,17 @@ class DeferredCache(Generic[KT, VT]):
         self.cache.set(key, value, callbacks=callbacks)
 
     def invalidate(self, key):
+        """Delete a key, or tree of entries
+
+        If the cache is backed by a regular dict, then "key" must be of
+        the right type for this cache
+
+        If the cache is backed by a TreeCache, then "key" must be a tuple, but
+        may be of lower cardinality than the TreeCache - in which case the whole
+        subtree is deleted.
+        """
         self.check_thread()
-        self.cache.pop(key, None)
+        self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
         # _pending_deferred_cache, which will (a) stop it being returned
@@ -299,20 +299,10 @@ class DeferredCache(Generic[KT, VT]):
         # run the invalidation callbacks now, rather than waiting for the
         # deferred to resolve.
         if entry:
-            entry.invalidate()
-
-    def invalidate_many(self, key: KT):
-        self.check_thread()
-        if not isinstance(key, tuple):
-            raise TypeError("The cache key must be a tuple not %r" % (type(key),))
-        key = cast(KT, key)
-        self.cache.del_multi(key)
-
-        # if we have a pending lookup for this key, remove it from the
-        # _pending_deferred_cache, as above
-        entry_dict = self._pending_deferred_cache.pop(key, None)
-        if entry_dict is not None:
-            for entry in iterate_tree_cache_entry(entry_dict):
+            # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
+            # case of a TreeCache, a dict of keys to cache entries. Either way calling
+            # iterate_tree_cache_entry on it will do the right thing.
+            for entry in iterate_tree_cache_entry(entry):
                 entry.invalidate()
 
     def invalidate_all(self):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 2ac24a2f25..d77e8edeea 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -48,7 +48,6 @@ F = TypeVar("F", bound=Callable[..., Any])
 class _CachedFunction(Generic[F]):
     invalidate = None  # type: Any
     invalidate_all = None  # type: Any
-    invalidate_many = None  # type: Any
     prefill = None  # type: Any
     cache = None  # type: Any
     num_args = None  # type: Any
@@ -262,6 +261,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
     ):
         super().__init__(orig, num_args=num_args, cache_context=cache_context)
 
+        if tree and self.num_args < 2:
+            raise RuntimeError(
+                "tree=True is nonsensical for cached functions with a single parameter"
+            )
+
         self.max_entries = max_entries
         self.tree = tree
         self.iterable = iterable
@@ -302,11 +306,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         wrapped = cast(_CachedFunction, _wrapped)
 
         if self.num_args == 1:
+            assert not self.tree
             wrapped.invalidate = lambda key: cache.invalidate(key[0])
             wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
         else:
             wrapped.invalidate = cache.invalidate
-            wrapped.invalidate_many = cache.invalidate_many
             wrapped.prefill = cache.prefill
 
         wrapped.invalidate_all = cache.invalidate_all
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 54df407ff7..d89e9d9b1d 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -152,7 +152,6 @@ class LruCache(Generic[KT, VT]):
     """
     Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
 
-    Supports del_multi only if cache_type=TreeCache
     If cache_type=TreeCache, all keys must be tuples.
     """
 
@@ -393,10 +392,16 @@ class LruCache(Generic[KT, VT]):
 
         @synchronized
         def cache_del_multi(key: KT) -> None:
+            """Delete an entry, or tree of entries
+
+            If the LruCache is backed by a regular dict, then "key" must be of
+            the right type for this cache
+
+            If the LruCache is backed by a TreeCache, then "key" must be a tuple, but
+            may be of lower cardinality than the TreeCache - in which case the whole
+            subtree is deleted.
             """
-            This will only work if constructed with cache_type=TreeCache
-            """
-            popped = cache.pop(key)
+            popped = cache.pop(key, None)
             if popped is None:
                 return
             # for each deleted node, we now need to remove it from the linked list
@@ -430,11 +435,10 @@ class LruCache(Generic[KT, VT]):
         self.set = cache_set
         self.setdefault = cache_set_default
         self.pop = cache_pop
+        self.del_multi = cache_del_multi
         # `invalidate` is exposed for consistency with DeferredCache, so that it can be
         # invalidated by the cache invalidation replication stream.
-        self.invalidate = cache_pop
-        if cache_type is TreeCache:
-            self.del_multi = cache_del_multi
+        self.invalidate = cache_del_multi
         self.len = synchronized(cache_len)
         self.contains = cache_contains
         self.clear = cache_clear
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 73502a8b06..a6df81ebff 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -89,6 +89,9 @@ class TreeCache:
             value. If the key is partial, the TreeCacheNode corresponding to the part
             of the tree that was removed.
         """
+        if not isinstance(key, tuple):
+            raise TypeError("The cache key must be a tuple not %r" % (type(key),))
+
         # a list of the nodes we have touched on the way down the tree
         nodes = []