summary refs log tree commit diff
path: root/synapse/util/caches/deferred_cache.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/deferred_cache.py')
-rw-r--r--synapse/util/caches/deferred_cache.py42
1 files changed, 16 insertions, 26 deletions
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 371e7e4dd0..1044139119 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -16,16 +16,7 @@
 
 import enum
 import threading
-from typing import (
-    Callable,
-    Generic,
-    Iterable,
-    MutableMapping,
-    Optional,
-    TypeVar,
-    Union,
-    cast,
-)
+from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union
 
 from prometheus_client import Gauge
 
@@ -91,7 +82,7 @@ class DeferredCache(Generic[KT, VT]):
         # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
         self._pending_deferred_cache = (
             cache_type()
-        )  # type: MutableMapping[KT, CacheEntry]
+        )  # type: Union[TreeCache, MutableMapping[KT, CacheEntry]]
 
         def metrics_cb():
             cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
@@ -287,8 +278,17 @@ class DeferredCache(Generic[KT, VT]):
         self.cache.set(key, value, callbacks=callbacks)
 
     def invalidate(self, key):
+        """Delete a key, or tree of entries
+
+        If the cache is backed by a regular dict, then "key" must be of
+        the right type for this cache
+
+        If the cache is backed by a TreeCache, then "key" must be a tuple, but
+        may be of lower cardinality than the TreeCache - in which case the whole
+        subtree is deleted.
+        """
         self.check_thread()
-        self.cache.pop(key, None)
+        self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
         # _pending_deferred_cache, which will (a) stop it being returned
@@ -299,20 +299,10 @@ class DeferredCache(Generic[KT, VT]):
         # run the invalidation callbacks now, rather than waiting for the
         # deferred to resolve.
         if entry:
-            entry.invalidate()
-
-    def invalidate_many(self, key: KT):
-        self.check_thread()
-        if not isinstance(key, tuple):
-            raise TypeError("The cache key must be a tuple not %r" % (type(key),))
-        key = cast(KT, key)
-        self.cache.del_multi(key)
-
-        # if we have a pending lookup for this key, remove it from the
-        # _pending_deferred_cache, as above
-        entry_dict = self._pending_deferred_cache.pop(key, None)
-        if entry_dict is not None:
-            for entry in iterate_tree_cache_entry(entry_dict):
+            # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
+            # case of a TreeCache, a dict of keys to cache entries. Either way calling
+            # iterate_tree_cache_entry on it will do the right thing.
+            for entry in iterate_tree_cache_entry(entry):
                 entry.invalidate()
 
     def invalidate_all(self):