summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/9973.misc1
-rw-r--r--synapse/replication/slave/storage/devices.py2
-rw-r--r--synapse/storage/databases/main/cache.py6
-rw-r--r--synapse/storage/databases/main/devices.py2
-rw-r--r--synapse/storage/databases/main/event_push_actions.py2
-rw-r--r--synapse/storage/databases/main/events.py8
-rw-r--r--synapse/storage/databases/main/receipts.py6
-rw-r--r--synapse/util/caches/deferred_cache.py42
-rw-r--r--synapse/util/caches/descriptors.py8
-rw-r--r--synapse/util/caches/lrucache.py18
-rw-r--r--synapse/util/caches/treecache.py3
-rw-r--r--tests/util/caches/test_descriptors.py6
12 files changed, 52 insertions, 52 deletions
diff --git a/changelog.d/9973.misc b/changelog.d/9973.misc
new file mode 100644
index 0000000000..7f22d42291
--- /dev/null
+++ b/changelog.d/9973.misc
@@ -0,0 +1 @@
+Make `LruCache.invalidate` support tree invalidation, and remove `invalidate_many`.
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 70207420a6..26bdead565 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -68,7 +68,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             if row.entity.startswith("@"):
                 self._device_list_stream_cache.entity_has_changed(row.entity, token)
                 self.get_cached_devices_for_user.invalidate((row.entity,))
-                self._get_cached_user_device.invalidate_many((row.entity,))
+                self._get_cached_user_device.invalidate((row.entity,))
                 self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
 
             else:
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index ecc1f935e2..f7872501a0 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -171,7 +171,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         self.get_latest_event_ids_in_room.invalidate((room_id,))
 
-        self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
+        self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
 
         if not backfilled:
             self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
@@ -184,8 +184,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self.get_invited_rooms_for_local_user.invalidate((state_key,))
 
         if relates_to:
-            self.get_relations_for_event.invalidate_many((relates_to,))
-            self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+            self.get_relations_for_event.invalidate((relates_to,))
+            self.get_aggregation_groups_for_event.invalidate((relates_to,))
             self.get_applicable_edit.invalidate((relates_to,))
 
     async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fd87ba71ab..18f07d96dc 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1282,7 +1282,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
         txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
-        txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
+        txn.call_after(self._get_cached_user_device.invalidate, (user_id,))
         txn.call_after(
             self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
         )
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 5845322118..d1237c65cc 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -860,7 +860,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                                   not be deleted.
         """
         txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+            self.get_unread_event_push_actions_by_room_for_user.invalidate,
             (room_id, user_id),
         )
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index fd25c8112d..897fa06639 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1748,9 +1748,9 @@ class PersistEventsStore:
             },
         )
 
-        txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,))
+        txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
         txn.call_after(
-            self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
+            self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
         )
 
         if rel_type == RelationTypes.REPLACE:
@@ -1903,7 +1903,7 @@ class PersistEventsStore:
 
                 for user_id in user_ids:
                     txn.call_after(
-                        self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+                        self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
                         (room_id, user_id),
                     )
 
@@ -1917,7 +1917,7 @@ class PersistEventsStore:
     def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
         # Sad that we have to blow away the cache for the whole room here
         txn.call_after(
-            self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+            self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
             (room_id,),
         )
         txn.execute(
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 3647276acb..edeaacd7a6 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -460,7 +460,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
-        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
+        self._get_linearized_receipts_for_room.invalidate((room_id,))
         self.get_last_receipt_event_id_for_user.invalidate(
             (user_id, room_id, receipt_type)
         )
@@ -659,9 +659,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
         txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
         # FIXME: This shouldn't invalidate the whole cache
-        txn.call_after(
-            self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
-        )
+        txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))
 
         self.db_pool.simple_delete_txn(
             txn,
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 371e7e4dd0..1044139119 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -16,16 +16,7 @@
 
 import enum
 import threading
-from typing import (
-    Callable,
-    Generic,
-    Iterable,
-    MutableMapping,
-    Optional,
-    TypeVar,
-    Union,
-    cast,
-)
+from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union
 
 from prometheus_client import Gauge
 
@@ -91,7 +82,7 @@ class DeferredCache(Generic[KT, VT]):
         # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
         self._pending_deferred_cache = (
             cache_type()
-        )  # type: MutableMapping[KT, CacheEntry]
+        )  # type: Union[TreeCache, MutableMapping[KT, CacheEntry]]
 
         def metrics_cb():
             cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
@@ -287,8 +278,17 @@ class DeferredCache(Generic[KT, VT]):
         self.cache.set(key, value, callbacks=callbacks)
 
     def invalidate(self, key):
+        """Delete a key, or tree of entries
+
+        If the cache is backed by a regular dict, then "key" must be of
+        the right type for this cache
+
+        If the cache is backed by a TreeCache, then "key" must be a tuple, but
+        may be of lower cardinality than the TreeCache - in which case the whole
+        subtree is deleted.
+        """
         self.check_thread()
-        self.cache.pop(key, None)
+        self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
         # _pending_deferred_cache, which will (a) stop it being returned
@@ -299,20 +299,10 @@ class DeferredCache(Generic[KT, VT]):
         # run the invalidation callbacks now, rather than waiting for the
         # deferred to resolve.
         if entry:
-            entry.invalidate()
-
-    def invalidate_many(self, key: KT):
-        self.check_thread()
-        if not isinstance(key, tuple):
-            raise TypeError("The cache key must be a tuple not %r" % (type(key),))
-        key = cast(KT, key)
-        self.cache.del_multi(key)
-
-        # if we have a pending lookup for this key, remove it from the
-        # _pending_deferred_cache, as above
-        entry_dict = self._pending_deferred_cache.pop(key, None)
-        if entry_dict is not None:
-            for entry in iterate_tree_cache_entry(entry_dict):
+            # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
+            # case of a TreeCache, a dict of keys to cache entries. Either way calling
+            # iterate_tree_cache_entry on it will do the right thing.
+            for entry in iterate_tree_cache_entry(entry):
                 entry.invalidate()
 
     def invalidate_all(self):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 2ac24a2f25..d77e8edeea 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -48,7 +48,6 @@ F = TypeVar("F", bound=Callable[..., Any])
 class _CachedFunction(Generic[F]):
     invalidate = None  # type: Any
     invalidate_all = None  # type: Any
-    invalidate_many = None  # type: Any
     prefill = None  # type: Any
     cache = None  # type: Any
     num_args = None  # type: Any
@@ -262,6 +261,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
     ):
         super().__init__(orig, num_args=num_args, cache_context=cache_context)
 
+        if tree and self.num_args < 2:
+            raise RuntimeError(
+                "tree=True is nonsensical for cached functions with a single parameter"
+            )
+
         self.max_entries = max_entries
         self.tree = tree
         self.iterable = iterable
@@ -302,11 +306,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         wrapped = cast(_CachedFunction, _wrapped)
 
         if self.num_args == 1:
+            assert not self.tree
             wrapped.invalidate = lambda key: cache.invalidate(key[0])
             wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
         else:
             wrapped.invalidate = cache.invalidate
-            wrapped.invalidate_many = cache.invalidate_many
             wrapped.prefill = cache.prefill
 
         wrapped.invalidate_all = cache.invalidate_all
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 54df407ff7..d89e9d9b1d 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -152,7 +152,6 @@ class LruCache(Generic[KT, VT]):
     """
     Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
 
-    Supports del_multi only if cache_type=TreeCache
     If cache_type=TreeCache, all keys must be tuples.
     """
 
@@ -393,10 +392,16 @@ class LruCache(Generic[KT, VT]):
 
         @synchronized
         def cache_del_multi(key: KT) -> None:
+            """Delete an entry, or tree of entries
+
+            If the LruCache is backed by a regular dict, then "key" must be of
+            the right type for this cache
+
+            If the LruCache is backed by a TreeCache, then "key" must be a tuple, but
+            may be of lower cardinality than the TreeCache - in which case the whole
+            subtree is deleted.
             """
-            This will only work if constructed with cache_type=TreeCache
-            """
-            popped = cache.pop(key)
+            popped = cache.pop(key, None)
             if popped is None:
                 return
             # for each deleted node, we now need to remove it from the linked list
@@ -430,11 +435,10 @@ class LruCache(Generic[KT, VT]):
         self.set = cache_set
         self.setdefault = cache_set_default
         self.pop = cache_pop
+        self.del_multi = cache_del_multi
         # `invalidate` is exposed for consistency with DeferredCache, so that it can be
         # invalidated by the cache invalidation replication stream.
-        self.invalidate = cache_pop
-        if cache_type is TreeCache:
-            self.del_multi = cache_del_multi
+        self.invalidate = cache_del_multi
         self.len = synchronized(cache_len)
         self.contains = cache_contains
         self.clear = cache_clear
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 73502a8b06..a6df81ebff 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -89,6 +89,9 @@ class TreeCache:
             value. If the key is partial, the TreeCacheNode corresponding to the part
             of the tree that was removed.
         """
+        if not isinstance(key, tuple):
+            raise TypeError("The cache key must be a tuple not %r" % (type(key),))
+
         # a list of the nodes we have touched on the way down the tree
         nodes = []
 
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index bbbc276697..0277998cbe 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -622,17 +622,17 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         self.assertEquals(callcount2[0], 1)
 
         a.func2.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+        self.assertEquals(a.func2.cache.cache.del_multi.call_count, 1)
 
         yield a.func2("foo")
         a.func2.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+        self.assertEquals(a.func2.cache.cache.del_multi.call_count, 2)
 
         self.assertEquals(callcount[0], 1)
         self.assertEquals(callcount2[0], 2)
 
         a.func.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+        self.assertEquals(a.func2.cache.cache.del_multi.call_count, 3)
         yield a.func("foo")
 
         self.assertEquals(callcount[0], 2)