diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 371e7e4dd0..1044139119 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -16,16 +16,7 @@
import enum
import threading
-from typing import (
- Callable,
- Generic,
- Iterable,
- MutableMapping,
- Optional,
- TypeVar,
- Union,
- cast,
-)
+from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union
from prometheus_client import Gauge
@@ -91,7 +82,7 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache = (
cache_type()
- ) # type: MutableMapping[KT, CacheEntry]
+ ) # type: Union[TreeCache, MutableMapping[KT, CacheEntry]]
def metrics_cb():
cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
@@ -287,8 +278,17 @@ class DeferredCache(Generic[KT, VT]):
self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key):
+ """Delete a key, or tree of entries
+
+ If the cache is backed by a regular dict, then "key" must be of
+ the right type for this cache
+
+ If the cache is backed by a TreeCache, then "key" must be a tuple, but
+ may be of lower cardinality than the TreeCache - in which case the whole
+ subtree is deleted.
+ """
self.check_thread()
- self.cache.pop(key, None)
+ self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned
@@ -299,20 +299,10 @@ class DeferredCache(Generic[KT, VT]):
# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry:
- entry.invalidate()
-
- def invalidate_many(self, key: KT):
- self.check_thread()
- if not isinstance(key, tuple):
- raise TypeError("The cache key must be a tuple not %r" % (type(key),))
- key = cast(KT, key)
- self.cache.del_multi(key)
-
- # if we have a pending lookup for this key, remove it from the
- # _pending_deferred_cache, as above
- entry_dict = self._pending_deferred_cache.pop(key, None)
- if entry_dict is not None:
- for entry in iterate_tree_cache_entry(entry_dict):
+ # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
+ # case of a TreeCache, a dict of keys to cache entries. Either way calling
+ # iterate_tree_cache_entry on it will do the right thing.
+ for entry in iterate_tree_cache_entry(entry):
entry.invalidate()
def invalidate_all(self):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 2ac24a2f25..d77e8edeea 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -48,7 +48,6 @@ F = TypeVar("F", bound=Callable[..., Any])
class _CachedFunction(Generic[F]):
invalidate = None # type: Any
invalidate_all = None # type: Any
- invalidate_many = None # type: Any
prefill = None # type: Any
cache = None # type: Any
num_args = None # type: Any
@@ -262,6 +261,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
):
super().__init__(orig, num_args=num_args, cache_context=cache_context)
+ if tree and self.num_args < 2:
+ raise RuntimeError(
+ "tree=True is nonsensical for cached functions with a single parameter"
+ )
+
self.max_entries = max_entries
self.tree = tree
self.iterable = iterable
@@ -302,11 +306,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
wrapped = cast(_CachedFunction, _wrapped)
if self.num_args == 1:
+ assert not self.tree
wrapped.invalidate = lambda key: cache.invalidate(key[0])
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
else:
wrapped.invalidate = cache.invalidate
- wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = cache.prefill
wrapped.invalidate_all = cache.invalidate_all
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 54df407ff7..d89e9d9b1d 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -152,7 +152,6 @@ class LruCache(Generic[KT, VT]):
"""
Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
- Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples.
"""
@@ -393,10 +392,16 @@ class LruCache(Generic[KT, VT]):
@synchronized
def cache_del_multi(key: KT) -> None:
+ """Delete an entry, or tree of entries
+
+ If the LruCache is backed by a regular dict, then "key" must be of
+ the right type for this cache
+
+ If the LruCache is backed by a TreeCache, then "key" must be a tuple, but
+ may be of lower cardinality than the TreeCache - in which case the whole
+ subtree is deleted.
"""
- This will only work if constructed with cache_type=TreeCache
- """
- popped = cache.pop(key)
+ popped = cache.pop(key, None)
if popped is None:
return
# for each deleted node, we now need to remove it from the linked list
@@ -430,11 +435,10 @@ class LruCache(Generic[KT, VT]):
self.set = cache_set
self.setdefault = cache_set_default
self.pop = cache_pop
+ self.del_multi = cache_del_multi
# `invalidate` is exposed for consistency with DeferredCache, so that it can be
# invalidated by the cache invalidation replication stream.
- self.invalidate = cache_pop
- if cache_type is TreeCache:
- self.del_multi = cache_del_multi
+ self.invalidate = cache_del_multi
self.len = synchronized(cache_len)
self.contains = cache_contains
self.clear = cache_clear
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 73502a8b06..a6df81ebff 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -89,6 +89,9 @@ class TreeCache:
value. If the key is partial, the TreeCacheNode corresponding to the part
of the tree that was removed.
"""
+ if not isinstance(key, tuple):
+ raise TypeError("The cache key must be a tuple not %r" % (type(key),))
+
# a list of the nodes we have touched on the way down the tree
nodes = []
|