diff --git a/changelog.d/16905.misc b/changelog.d/16905.misc
new file mode 100644
index 0000000000..c5f47eb3e9
--- /dev/null
+++ b/changelog.d/16905.misc
@@ -0,0 +1 @@
+Don't invalidate the entire event cache when we purge history.
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 7314d87404..bfd492d95d 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -373,7 +373,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
deleted.
"""
- self._invalidate_local_get_event_cache_all() # type: ignore[attr-defined]
+ self._invalidate_local_get_event_cache_room_id(room_id) # type: ignore[attr-defined]
self._attempt_to_invalidate_cache("have_seen_event", (room_id,))
self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,))
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 1fd458b510..9c3775bb7c 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -268,6 +268,8 @@ class EventsWorkerStore(SQLBaseStore):
] = AsyncLruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
+            # `extra_index_cb` returns a tuple, as that is the cache's key type
+ extra_index_cb=lambda _, v: (v.event.room_id,),
)
# Map from event ID to a deferred that will result in a map from event
@@ -782,9 +784,9 @@ class EventsWorkerStore(SQLBaseStore):
if missing_events_ids:
- async def get_missing_events_from_cache_or_db() -> Dict[
- str, EventCacheEntry
- ]:
+ async def get_missing_events_from_cache_or_db() -> (
+ Dict[str, EventCacheEntry]
+ ):
"""Fetches the events in `missing_event_ids` from the database.
Also creates entries in `self._current_event_fetches` to allow
@@ -910,12 +912,12 @@ class EventsWorkerStore(SQLBaseStore):
self._event_ref.pop(event_id, None)
self._current_event_fetches.pop(event_id, None)
- def _invalidate_local_get_event_cache_all(self) -> None:
- """Clears the in-memory get event caches.
+ def _invalidate_local_get_event_cache_room_id(self, room_id: str) -> None:
+ """Clears the in-memory get event caches for a room.
Used when we purge room history.
"""
- self._get_event_cache.clear()
+ self._get_event_cache.invalidate_on_extra_index_local((room_id,))
self._event_ref.clear()
self._current_event_fetches.clear()
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 6e8c1e84ac..a1b4f5b6a7 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -35,6 +35,7 @@ from typing import (
Iterable,
List,
Optional,
+ Set,
Tuple,
Type,
TypeVar,
@@ -386,6 +387,7 @@ class LruCache(Generic[KT, VT]):
apply_cache_factor_from_config: bool = True,
clock: Optional[Clock] = None,
prune_unread_entries: bool = True,
+ extra_index_cb: Optional[Callable[[KT, VT], KT]] = None,
):
"""
Args:
@@ -416,6 +418,20 @@ class LruCache(Generic[KT, VT]):
prune_unread_entries: If True, cache entries that haven't been read recently
will be evicted from the cache in the background. Set to False to
opt-out of this behaviour.
+
+ extra_index_cb: If provided, the cache keeps a second index from a
+ (different) key to a cache entry based on the return value of
+ the callback. This can then be used to invalidate entries based
+ on the second type of key.
+
+ For example, for the event cache this would be a callback that
+ maps an event to its room ID, allowing invalidation of all
+ events in a given room.
+
+            Note: Though index keys and cache keys share the same type, they
+            live in different namespaces.
+
+ Note: The new key does not have to be unique.
"""
# Default `clock` to something sensible. Note that we rename it to
# `real_clock` so that mypy doesn't think its still `Optional`.
@@ -463,6 +479,8 @@ class LruCache(Generic[KT, VT]):
lock = threading.Lock()
+ extra_index: Dict[KT, Set[KT]] = {}
+
def evict() -> None:
while cache_len() > self.max_size:
# Get the last node in the list (i.e. the oldest node).
@@ -521,6 +539,11 @@ class LruCache(Generic[KT, VT]):
if size_callback:
cached_cache_len[0] += size_callback(node.value)
+ if extra_index_cb:
+ index_key = extra_index_cb(node.key, node.value)
+ mapped_keys = extra_index.setdefault(index_key, set())
+ mapped_keys.add(node.key)
+
if caches.TRACK_MEMORY_USAGE and metrics:
metrics.inc_memory_usage(node.memory)
@@ -537,6 +560,14 @@ class LruCache(Generic[KT, VT]):
node.run_and_clear_callbacks()
+ if extra_index_cb:
+ index_key = extra_index_cb(node.key, node.value)
+ mapped_keys = extra_index.get(index_key)
+ if mapped_keys is not None:
+ mapped_keys.discard(node.key)
+ if not mapped_keys:
+ extra_index.pop(index_key, None)
+
if caches.TRACK_MEMORY_USAGE and metrics:
metrics.dec_memory_usage(node.memory)
@@ -748,6 +779,8 @@ class LruCache(Generic[KT, VT]):
if size_callback:
cached_cache_len[0] = 0
+ extra_index.clear()
+
if caches.TRACK_MEMORY_USAGE and metrics:
metrics.clear_memory_usage()
@@ -755,6 +788,28 @@ class LruCache(Generic[KT, VT]):
def cache_contains(key: KT) -> bool:
return key in cache
+ @synchronized
+ def cache_invalidate_on_extra_index(index_key: KT) -> None:
+ """Invalidates all entries that match the given extra index key.
+
+ This can only be called when `extra_index_cb` was specified.
+ """
+
+ assert extra_index_cb is not None
+
+ keys = extra_index.pop(index_key, None)
+ if not keys:
+ return
+
+ for key in keys:
+ node = cache.pop(key, None)
+ if not node:
+ continue
+
+ evicted_len = delete_node(node)
+ if metrics:
+ metrics.inc_evictions(EvictionReason.invalidation, evicted_len)
+
# make sure that we clear out any excess entries after we get resized.
self._on_resize = evict
@@ -771,6 +826,7 @@ class LruCache(Generic[KT, VT]):
self.len = synchronized(cache_len)
self.contains = cache_contains
self.clear = cache_clear
+ self.invalidate_on_extra_index = cache_invalidate_on_extra_index
def __getitem__(self, key: KT) -> VT:
result = self.get(key, _Sentinel.sentinel)
@@ -864,6 +920,9 @@ class AsyncLruCache(Generic[KT, VT]):
# This method should invalidate any external cache and then invalidate the LruCache.
return self._lru_cache.invalidate(key)
+ def invalidate_on_extra_index_local(self, index_key: KT) -> None:
+ self._lru_cache.invalidate_on_extra_index(index_key)
+
def invalidate_local(self, key: KT) -> None:
"""Remove an entry from the local cache
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index dcc2b4be89..3f0d8139f8 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -383,3 +383,34 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
# the items should still be in the cache
self.assertEqual(cache.get("key1"), 1)
self.assertEqual(cache.get("key2"), 2)
+
+
+class ExtraIndexLruCacheTestCase(unittest.HomeserverTestCase):
+ def test_invalidate_simple(self) -> None:
+ cache: LruCache[str, int] = LruCache(10, extra_index_cb=lambda k, v: str(v))
+ cache["key1"] = 1
+ cache["key2"] = 2
+
+ cache.invalidate_on_extra_index("key1")
+ self.assertEqual(cache.get("key1"), 1)
+ self.assertEqual(cache.get("key2"), 2)
+
+ cache.invalidate_on_extra_index("1")
+ self.assertEqual(cache.get("key1"), None)
+ self.assertEqual(cache.get("key2"), 2)
+
+ def test_invalidate_multi(self) -> None:
+ cache: LruCache[str, int] = LruCache(10, extra_index_cb=lambda k, v: str(v))
+ cache["key1"] = 1
+ cache["key2"] = 1
+ cache["key3"] = 2
+
+ cache.invalidate_on_extra_index("key1")
+ self.assertEqual(cache.get("key1"), 1)
+ self.assertEqual(cache.get("key2"), 1)
+ self.assertEqual(cache.get("key3"), 2)
+
+ cache.invalidate_on_extra_index("1")
+ self.assertEqual(cache.get("key1"), None)
+ self.assertEqual(cache.get("key2"), None)
+ self.assertEqual(cache.get("key3"), 2)
|