 changelog.d/16905.misc                           |  1
 synapse/storage/databases/main/cache.py          |  2
 synapse/storage/databases/main/events_worker.py  | 14
 synapse/util/caches/lrucache.py                  | 59
 tests/util/test_lrucache.py                      | 31
 5 files changed, 100 insertions(+), 7 deletions(-)
diff --git a/changelog.d/16905.misc b/changelog.d/16905.misc
new file mode 100644
index 0000000000..c5f47eb3e9
--- /dev/null
+++ b/changelog.d/16905.misc
@@ -0,0 +1 @@
+Don't invalidate the entire event cache when we purge history.
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 7314d87404..bfd492d95d 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -373,7 +373,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         deleted.
         """
 
-        self._invalidate_local_get_event_cache_all()  # type: ignore[attr-defined]
+        self._invalidate_local_get_event_cache_room_id(room_id)  # type: ignore[attr-defined]
 
         self._attempt_to_invalidate_cache("have_seen_event", (room_id,))
         self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,))
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 1fd458b510..9c3775bb7c 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -268,6 +268,8 @@ class EventsWorkerStore(SQLBaseStore):
         ] = AsyncLruCache(
             cache_name="*getEvent*",
             max_size=hs.config.caches.event_cache_size,
+            # `extra_index_cb` returns a tuple, as that is this cache's key type
+            extra_index_cb=lambda _, v: (v.event.room_id,),
         )
 
         # Map from event ID to a deferred that will result in a map from event
@@ -782,9 +784,9 @@ class EventsWorkerStore(SQLBaseStore):
 
         if missing_events_ids:
 
-            async def get_missing_events_from_cache_or_db() -> Dict[
-                str, EventCacheEntry
-            ]:
+            async def get_missing_events_from_cache_or_db() -> (
+                Dict[str, EventCacheEntry]
+            ):
                 """Fetches the events in `missing_event_ids` from the database.
 
                 Also creates entries in `self._current_event_fetches` to allow
@@ -910,12 +912,12 @@ class EventsWorkerStore(SQLBaseStore):
         self._event_ref.pop(event_id, None)
         self._current_event_fetches.pop(event_id, None)
 
-    def _invalidate_local_get_event_cache_all(self) -> None:
-        """Clears the in-memory get event caches.
+    def _invalidate_local_get_event_cache_room_id(self, room_id: str) -> None:
+        """Clears the in-memory get event caches for a room.
 
         Used when we purge room history.
         """
-        self._get_event_cache.clear()
+        self._get_event_cache.invalidate_on_extra_index_local((room_id,))
         self._event_ref.clear()
         self._current_event_fetches.clear()
 
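
Informally, the wiring above amounts to the following sketch, run outside
Synapse with a stand-in event type instead of EventCacheEntry (FakeEvent and
the room/event IDs below are made up). It uses the plain LruCache directly; in
Synapse the call goes through AsyncLruCache.invalidate_on_extra_index_local,
which simply forwards to the wrapped LruCache:

    from dataclasses import dataclass
    from typing import Tuple

    from synapse.util.caches.lrucache import LruCache

    @dataclass
    class FakeEvent:
        event_id: str
        room_id: str

    # Keys are (event_id,) tuples, mirroring the event cache above; the
    # callback maps each cached value to a (room_id,) tuple.
    event_cache: LruCache[Tuple[str], FakeEvent] = LruCache(
        100,
        extra_index_cb=lambda _key, ev: (ev.room_id,),
    )

    event_cache[("$a",)] = FakeEvent("$a", "!room1:example.org")
    event_cache[("$b",)] = FakeEvent("$b", "!room2:example.org")

    # Roughly what _invalidate_local_get_event_cache_room_id does for the
    # purged room.
    event_cache.invalidate_on_extra_index(("!room1:example.org",))

    assert event_cache.get(("$a",)) is None
    assert event_cache.get(("$b",)) is not None
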
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 6e8c1e84ac..a1b4f5b6a7 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -35,6 +35,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Set,
     Tuple,
     Type,
     TypeVar,
@@ -386,6 +387,7 @@ class LruCache(Generic[KT, VT]):
         apply_cache_factor_from_config: bool = True,
         clock: Optional[Clock] = None,
         prune_unread_entries: bool = True,
+        extra_index_cb: Optional[Callable[[KT, VT], KT]] = None,
     ):
         """
         Args:
@@ -416,6 +418,20 @@ class LruCache(Generic[KT, VT]):
             prune_unread_entries: If True, cache entries that haven't been read recently
                 will be evicted from the cache in the background. Set to False to
                 opt-out of this behaviour.
+
+            extra_index_cb: If provided, the cache keeps a second index from a
+                (different) key to a cache entry based on the return value of
+                the callback. This can then be used to invalidate entries based
+                on the second type of key.
+
+                For example, for the event cache this would be a callback that
+                maps an event to its room ID, allowing invalidation of all
+                events in a given room.
+
+                Note: Though the two kinds of key share the same Python type,
+                they live in separate namespaces.
+
+                Note: The index key returned by the callback need not be unique.
         """
         # Default `clock` to something sensible. Note that we rename it to
         # `real_clock` so that mypy doesn't think its still `Optional`.
@@ -463,6 +479,8 @@ class LruCache(Generic[KT, VT]):
 
         lock = threading.Lock()
 
+        extra_index: Dict[KT, Set[KT]] = {}
+
         def evict() -> None:
             while cache_len() > self.max_size:
                 # Get the last node in the list (i.e. the oldest node).
@@ -521,6 +539,11 @@ class LruCache(Generic[KT, VT]):
             if size_callback:
                 cached_cache_len[0] += size_callback(node.value)
 
+            if extra_index_cb:
+                index_key = extra_index_cb(node.key, node.value)
+                mapped_keys = extra_index.setdefault(index_key, set())
+                mapped_keys.add(node.key)
+
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.inc_memory_usage(node.memory)
 
@@ -537,6 +560,14 @@ class LruCache(Generic[KT, VT]):
 
             node.run_and_clear_callbacks()
 
+            if extra_index_cb:
+                index_key = extra_index_cb(node.key, node.value)
+                mapped_keys = extra_index.get(index_key)
+                if mapped_keys is not None:
+                    mapped_keys.discard(node.key)
+                    if not mapped_keys:
+                        extra_index.pop(index_key, None)
+
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.dec_memory_usage(node.memory)
 
@@ -748,6 +779,8 @@ class LruCache(Generic[KT, VT]):
             if size_callback:
                 cached_cache_len[0] = 0
 
+            extra_index.clear()
+
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.clear_memory_usage()
 
@@ -755,6 +788,28 @@ class LruCache(Generic[KT, VT]):
         def cache_contains(key: KT) -> bool:
             return key in cache
 
+        @synchronized
+        def cache_invalidate_on_extra_index(index_key: KT) -> None:
+            """Invalidates all entries that match the given extra index key.
+
+            This can only be called when `extra_index_cb` was specified.
+            """
+
+            assert extra_index_cb is not None
+
+            keys = extra_index.pop(index_key, None)
+            if not keys:
+                return
+
+            for key in keys:
+                node = cache.pop(key, None)
+                if not node:
+                    continue
+
+                evicted_len = delete_node(node)
+                if metrics:
+                    metrics.inc_evictions(EvictionReason.invalidation, evicted_len)
+
         # make sure that we clear out any excess entries after we get resized.
         self._on_resize = evict
 
@@ -771,6 +826,7 @@ class LruCache(Generic[KT, VT]):
         self.len = synchronized(cache_len)
         self.contains = cache_contains
         self.clear = cache_clear
+        self.invalidate_on_extra_index = cache_invalidate_on_extra_index
 
     def __getitem__(self, key: KT) -> VT:
         result = self.get(key, _Sentinel.sentinel)
@@ -864,6 +920,9 @@ class AsyncLruCache(Generic[KT, VT]):
         # This method should invalidate any external cache and then invalidate the LruCache.
         return self._lru_cache.invalidate(key)
 
+    def invalidate_on_extra_index_local(self, index_key: KT) -> None:
+        self._lru_cache.invalidate_on_extra_index(index_key)
+
     def invalidate_local(self, key: KT) -> None:
         """Remove an entry from the local cache
 
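
For reference, the bookkeeping added above boils down to a dict of sets from
index key to primary keys, kept in step by add_node and delete_node and
drained by cache_invalidate_on_extra_index. The following standalone class is
a simplified model of that logic, for illustration only (it is not Synapse
code and ignores size tracking, callbacks and metrics):

    from typing import Callable, Dict, Generic, Optional, Set, TypeVar

    KT = TypeVar("KT")
    VT = TypeVar("VT")

    class TinyIndexedCache(Generic[KT, VT]):
        def __init__(
            self, extra_index_cb: Optional[Callable[[KT, VT], KT]] = None
        ) -> None:
            self._cache: Dict[KT, VT] = {}
            self._extra_index: Dict[KT, Set[KT]] = {}
            self._extra_index_cb = extra_index_cb

        def set(self, key: KT, value: VT) -> None:
            self._cache[key] = value
            if self._extra_index_cb:
                # Mirrors add_node: record the key under its index key.
                index_key = self._extra_index_cb(key, value)
                self._extra_index.setdefault(index_key, set()).add(key)

        def delete(self, key: KT) -> None:
            if key not in self._cache:
                return
            value = self._cache.pop(key)
            if not self._extra_index_cb:
                return
            # Mirrors delete_node: drop the key from the index, and drop the
            # index entry entirely once its set is empty.
            index_key = self._extra_index_cb(key, value)
            keys = self._extra_index.get(index_key)
            if keys is not None:
                keys.discard(key)
                if not keys:
                    self._extra_index.pop(index_key, None)

        def invalidate_on_extra_index(self, index_key: KT) -> None:
            # Mirrors cache_invalidate_on_extra_index: pop the whole set and
            # delete every matching entry.
            for key in self._extra_index.pop(index_key, set()):
                self._cache.pop(key, None)
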
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index dcc2b4be89..3f0d8139f8 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -383,3 +383,34 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
         # the items should still be in the cache
         self.assertEqual(cache.get("key1"), 1)
         self.assertEqual(cache.get("key2"), 2)
+
+
+class ExtraIndexLruCacheTestCase(unittest.HomeserverTestCase):
+    def test_invalidate_simple(self) -> None:
+        cache: LruCache[str, int] = LruCache(10, extra_index_cb=lambda k, v: str(v))
+        cache["key1"] = 1
+        cache["key2"] = 2
+
+        cache.invalidate_on_extra_index("key1")
+        self.assertEqual(cache.get("key1"), 1)
+        self.assertEqual(cache.get("key2"), 2)
+
+        cache.invalidate_on_extra_index("1")
+        self.assertEqual(cache.get("key1"), None)
+        self.assertEqual(cache.get("key2"), 2)
+
+    def test_invalidate_multi(self) -> None:
+        cache: LruCache[str, int] = LruCache(10, extra_index_cb=lambda k, v: str(v))
+        cache["key1"] = 1
+        cache["key2"] = 1
+        cache["key3"] = 2
+
+        cache.invalidate_on_extra_index("key1")
+        self.assertEqual(cache.get("key1"), 1)
+        self.assertEqual(cache.get("key2"), 1)
+        self.assertEqual(cache.get("key3"), 2)
+
+        cache.invalidate_on_extra_index("1")
+        self.assertEqual(cache.get("key1"), None)
+        self.assertEqual(cache.get("key2"), None)
+        self.assertEqual(cache.get("key3"), 2)
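
Not part of this change, but a possible extra check in the same vein: ordinary
per-key invalidation also keeps the extra index in step (delete_node discards
the key from the index), so a later index-based invalidation only touches
entries that are still cached. Sketched as a plain snippet rather than a test:

    from synapse.util.caches.lrucache import LruCache

    cache: LruCache[str, int] = LruCache(10, extra_index_cb=lambda k, v: str(v))
    cache["key1"] = 1
    cache["key2"] = 1

    # Drop key1 via the ordinary per-key path.
    cache.invalidate("key1")
    assert cache.get("key1") is None

    # Index-based invalidation still removes the remaining entry, and the
    # already-removed key is simply no longer in the index.
    cache.invalidate_on_extra_index("1")
    assert cache.get("key2") is None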