diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 371e7e4dd0..1044139119 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -16,16 +16,7 @@
import enum
import threading
-from typing import (
- Callable,
- Generic,
- Iterable,
- MutableMapping,
- Optional,
- TypeVar,
- Union,
- cast,
-)
+from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union
from prometheus_client import Gauge
@@ -91,7 +82,7 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache = (
cache_type()
- ) # type: MutableMapping[KT, CacheEntry]
+ ) # type: Union[TreeCache, MutableMapping[KT, CacheEntry]]
def metrics_cb():
cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
@@ -287,8 +278,17 @@ class DeferredCache(Generic[KT, VT]):
self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key):
+ """Delete a key, or tree of entries
+
+ If the cache is backed by a regular dict, then "key" must be of
+ the right type for this cache
+
+ If the cache is backed by a TreeCache, then "key" must be a tuple, but
+ may be of lower cardinality than the TreeCache - in which case the whole
+ subtree is deleted.
+ """
self.check_thread()
- self.cache.pop(key, None)
+ self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned
@@ -299,20 +299,10 @@ class DeferredCache(Generic[KT, VT]):
# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry:
- entry.invalidate()
-
- def invalidate_many(self, key: KT):
- self.check_thread()
- if not isinstance(key, tuple):
- raise TypeError("The cache key must be a tuple not %r" % (type(key),))
- key = cast(KT, key)
- self.cache.del_multi(key)
-
- # if we have a pending lookup for this key, remove it from the
- # _pending_deferred_cache, as above
- entry_dict = self._pending_deferred_cache.pop(key, None)
- if entry_dict is not None:
- for entry in iterate_tree_cache_entry(entry_dict):
+ # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
+ # case of a TreeCache, a dict of keys to cache entries. Either way calling
+ # iterate_tree_cache_entry on it will do the right thing.
+ for entry in iterate_tree_cache_entry(entry):
entry.invalidate()
def invalidate_all(self):
|