summary refs log tree commit diff
path: root/synapse/util/caches/lrucache.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/lrucache.py')
-rw-r--r--synapse/util/caches/lrucache.py59
1 files changed, 59 insertions, 0 deletions
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 6e8c1e84ac..a1b4f5b6a7 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -35,6 +35,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Set,
     Tuple,
     Type,
     TypeVar,
@@ -386,6 +387,7 @@ class LruCache(Generic[KT, VT]):
         apply_cache_factor_from_config: bool = True,
         clock: Optional[Clock] = None,
         prune_unread_entries: bool = True,
+        extra_index_cb: Optional[Callable[[KT, VT], KT]] = None,
     ):
         """
         Args:
@@ -416,6 +418,20 @@ class LruCache(Generic[KT, VT]):
             prune_unread_entries: If True, cache entries that haven't been read recently
                 will be evicted from the cache in the background. Set to False to
                 opt-out of this behaviour.
+
+            extra_index_cb: If provided, the cache keeps a second index from a
+                (different) key to a cache entry based on the return value of
+                the callback. This can then be used to invalidate entries based
+                on the second type of key.
+
+                For example, for the event cache this would be a callback that
+                maps an event to its room ID, allowing invalidation of all
+                events in a given room.
+
+                Note: Though the two types of key have the same type, they are
+                in different namespaces.
+
+                Note: The new key does not have to be unique.
         """
         # Default `clock` to something sensible. Note that we rename it to
         # `real_clock` so that mypy doesn't think its still `Optional`.
@@ -463,6 +479,8 @@ class LruCache(Generic[KT, VT]):
 
         lock = threading.Lock()
 
+        extra_index: Dict[KT, Set[KT]] = {}
+
         def evict() -> None:
             while cache_len() > self.max_size:
                 # Get the last node in the list (i.e. the oldest node).
@@ -521,6 +539,11 @@ class LruCache(Generic[KT, VT]):
             if size_callback:
                 cached_cache_len[0] += size_callback(node.value)
 
+            if extra_index_cb:
+                index_key = extra_index_cb(node.key, node.value)
+                mapped_keys = extra_index.setdefault(index_key, set())
+                mapped_keys.add(node.key)
+
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.inc_memory_usage(node.memory)
 
@@ -537,6 +560,14 @@ class LruCache(Generic[KT, VT]):
 
             node.run_and_clear_callbacks()
 
+            if extra_index_cb:
+                index_key = extra_index_cb(node.key, node.value)
+                mapped_keys = extra_index.get(index_key)
+                if mapped_keys is not None:
+                    mapped_keys.discard(node.key)
+                    if not mapped_keys:
+                        extra_index.pop(index_key, None)
+
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.dec_memory_usage(node.memory)
 
@@ -748,6 +779,8 @@ class LruCache(Generic[KT, VT]):
             if size_callback:
                 cached_cache_len[0] = 0
 
+            extra_index.clear()
+
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.clear_memory_usage()
 
@@ -755,6 +788,28 @@ class LruCache(Generic[KT, VT]):
         def cache_contains(key: KT) -> bool:
             return key in cache
 
+        @synchronized
+        def cache_invalidate_on_extra_index(index_key: KT) -> None:
+            """Invalidates all entries that match the given extra index key.
+
+            This can only be called when `extra_index_cb` was specified.
+            """
+
+            assert extra_index_cb is not None
+
+            keys = extra_index.pop(index_key, None)
+            if not keys:
+                return
+
+            for key in keys:
+                node = cache.pop(key, None)
+                if not node:
+                    continue
+
+                evicted_len = delete_node(node)
+                if metrics:
+                    metrics.inc_evictions(EvictionReason.invalidation, evicted_len)
+
         # make sure that we clear out any excess entries after we get resized.
         self._on_resize = evict
 
@@ -771,6 +826,7 @@ class LruCache(Generic[KT, VT]):
         self.len = synchronized(cache_len)
         self.contains = cache_contains
         self.clear = cache_clear
+        self.invalidate_on_extra_index = cache_invalidate_on_extra_index
 
     def __getitem__(self, key: KT) -> VT:
         result = self.get(key, _Sentinel.sentinel)
@@ -864,6 +920,9 @@ class AsyncLruCache(Generic[KT, VT]):
         # This method should invalidate any external cache and then invalidate the LruCache.
         return self._lru_cache.invalidate(key)
 
+    def invalidate_on_extra_index_local(self, index_key: KT) -> None:
+        self._lru_cache.invalidate_on_extra_index(index_key)
+
     def invalidate_local(self, key: KT) -> None:
         """Remove an entry from the local cache