diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index d89e9d9b1d..4b9d0433ff 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import threading
+import weakref
from functools import wraps
from typing import (
+ TYPE_CHECKING,
Any,
Callable,
Collection,
@@ -31,10 +34,19 @@ from typing import (
from typing_extensions import Literal
+from twisted.internet import reactor
+
from synapse.config import cache as cache_config
-from synapse.util import caches
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.util import Clock, caches
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.linked_list import ListNode
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
try:
from pympler.asizeof import Asizer
@@ -82,19 +94,126 @@ def enumerate_leaves(node, depth):
yield m
+P = TypeVar("P")
+
+
+class _TimedListNode(ListNode[P]):
+ """A `ListNode` that tracks last access time."""
+
+ __slots__ = ["last_access_ts_secs"]
+
+ def update_last_access(self, clock: Clock):
+ self.last_access_ts_secs = int(clock.time())
+
+
+# Whether to insert new cache entries to the global list. We only add to it if
+# time based eviction is enabled.
+USE_GLOBAL_LIST = False
+
+# A linked list of all cache entries, allowing efficient time based eviction.
+GLOBAL_ROOT = ListNode["_Node"].create_root_node()
+
+
+@wrap_as_background_process("LruCache._expire_old_entries")
+async def _expire_old_entries(clock: Clock, expiry_seconds: int):
+ """Walks the global cache list to find cache entries that haven't been
+ accessed in the given number of seconds.
+ """
+
+ now = int(clock.time())
+ node = GLOBAL_ROOT.prev_node
+ assert node is not None
+
+ i = 0
+
+ logger.debug("Searching for stale caches")
+
+ while node is not GLOBAL_ROOT:
+ # Only the root node isn't a `_TimedListNode`.
+ assert isinstance(node, _TimedListNode)
+
+ if node.last_access_ts_secs > now - expiry_seconds:
+ break
+
+ cache_entry = node.get_cache_entry()
+ next_node = node.prev_node
+
+ # The node should always have a reference to a cache entry and a valid
+ # `prev_node`, as we only drop them when we remove the node from the
+ # list.
+ assert next_node is not None
+ assert cache_entry is not None
+ cache_entry.drop_from_cache()
+
+ # If we do lots of work at once we yield to allow other stuff to happen.
+ if (i + 1) % 10000 == 0:
+ logger.debug("Waiting during drop")
+ await clock.sleep(0)
+ logger.debug("Waking during drop")
+
+ node = next_node
+
+ # If we've yielded then our current node may have been evicted, so we
+ # need to check that its still valid.
+ if node.prev_node is None:
+ break
+
+ i += 1
+
+ logger.info("Dropped %d items from caches", i)
+
+
+def setup_expire_lru_cache_entries(hs: "HomeServer"):
+ """Start a background job that expires all cache entries if they have not
+ been accessed for the given number of seconds.
+ """
+ if not hs.config.caches.expiry_time_msec:
+ return
+
+ logger.info(
+ "Expiring LRU caches after %d seconds", hs.config.caches.expiry_time_msec / 1000
+ )
+
+ global USE_GLOBAL_LIST
+ USE_GLOBAL_LIST = True
+
+ clock = hs.get_clock()
+ clock.looping_call(
+ _expire_old_entries, 30 * 1000, clock, hs.config.caches.expiry_time_msec / 1000
+ )
+
+
class _Node:
- __slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"]
+ __slots__ = [
+ "_list_node",
+ "_global_list_node",
+ "_cache",
+ "key",
+ "value",
+ "callbacks",
+ "memory",
+ ]
def __init__(
self,
- prev_node,
- next_node,
+ root: "ListNode[_Node]",
key,
value,
+ cache: "weakref.ReferenceType[LruCache]",
+ clock: Clock,
callbacks: Collection[Callable[[], None]] = (),
):
- self.prev_node = prev_node
- self.next_node = next_node
+ self._list_node = ListNode.insert_after(self, root)
+ self._global_list_node = None
+ if USE_GLOBAL_LIST:
+ self._global_list_node = _TimedListNode.insert_after(self, GLOBAL_ROOT)
+ self._global_list_node.update_last_access(clock)
+
+ # We store a weak reference to the cache object so that this _Node can
+ # remove itself from the cache. If the cache is dropped we ensure we
+ # remove our entries in the lists.
+ self._cache = cache
+
self.key = key
self.value = value
@@ -116,11 +235,16 @@ class _Node:
self.memory = (
_get_size_of(key)
+ _get_size_of(value)
+ + _get_size_of(self._list_node, recurse=False)
+ _get_size_of(self.callbacks, recurse=False)
+ _get_size_of(self, recurse=False)
)
self.memory += _get_size_of(self.memory, recurse=False)
+ if self._global_list_node:
+ self.memory += _get_size_of(self._global_list_node, recurse=False)
+ self.memory += _get_size_of(self._global_list_node.last_access_ts_secs)
+
def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
"""Add to stored list of callbacks, removing duplicates."""
@@ -147,6 +271,32 @@ class _Node:
self.callbacks = None
+ def drop_from_cache(self) -> None:
+ """Drop this node from the cache.
+
+ Ensures that the entry gets removed from the cache and that we get
+ removed from all lists.
+ """
+ cache = self._cache()
+ if not cache or not cache.pop(self.key, None):
+ # `cache.pop` should call `drop_from_lists()`, unless this Node had
+ # already been removed from the cache.
+ self.drop_from_lists()
+
+ def drop_from_lists(self) -> None:
+ """Remove this node from the cache lists."""
+ self._list_node.remove_from_list()
+
+ if self._global_list_node:
+ self._global_list_node.remove_from_list()
+
+ def move_to_front(self, clock: Clock, cache_list_root: ListNode) -> None:
+ """Moves this node to the front of all the lists its in."""
+ self._list_node.move_after(cache_list_root)
+ if self._global_list_node:
+ self._global_list_node.move_after(GLOBAL_ROOT)
+ self._global_list_node.update_last_access(clock)
+
class LruCache(Generic[KT, VT]):
"""
@@ -163,6 +313,7 @@ class LruCache(Generic[KT, VT]):
size_callback: Optional[Callable] = None,
metrics_collection_callback: Optional[Callable[[], None]] = None,
apply_cache_factor_from_config: bool = True,
+ clock: Optional[Clock] = None,
):
"""
Args:
@@ -188,6 +339,13 @@ class LruCache(Generic[KT, VT]):
apply_cache_factor_from_config (bool): If true, `max_size` will be
multiplied by a cache factor derived from the homeserver config
"""
+ # Default `clock` to something sensible. Note that we rename it to
+ # `real_clock` so that mypy doesn't think its still `Optional`.
+ if clock is None:
+ real_clock = Clock(reactor)
+ else:
+ real_clock = clock
+
cache = cache_type()
self.cache = cache # Used for introspection.
self.apply_cache_factor_from_config = apply_cache_factor_from_config
@@ -219,17 +377,31 @@ class LruCache(Generic[KT, VT]):
# this is exposed for access from outside this class
self.metrics = metrics
- list_root = _Node(None, None, None, None)
- list_root.next_node = list_root
- list_root.prev_node = list_root
+ # We create a single weakref to self here so that we don't need to keep
+ # creating more each time we create a `_Node`.
+ weak_ref_to_self = weakref.ref(self)
+
+ list_root = ListNode[_Node].create_root_node()
lock = threading.Lock()
def evict():
while cache_len() > self.max_size:
+ # Get the last node in the list (i.e. the oldest node).
todelete = list_root.prev_node
- evicted_len = delete_node(todelete)
- cache.pop(todelete.key, None)
+
+ # The list root should always have a valid `prev_node` if the
+ # cache is not empty.
+ assert todelete is not None
+
+ # The node should always have a reference to a cache entry, as
+ # we only drop the cache entry when we remove the node from the
+ # list.
+ node = todelete.get_cache_entry()
+ assert node is not None
+
+ evicted_len = delete_node(node)
+ cache.pop(node.key, None)
if metrics:
metrics.inc_evictions(evicted_len)
@@ -255,11 +427,7 @@ class LruCache(Generic[KT, VT]):
self.len = synchronized(cache_len)
def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
- prev_node = list_root
- next_node = prev_node.next_node
- node = _Node(prev_node, next_node, key, value, callbacks)
- prev_node.next_node = node
- next_node.prev_node = node
+ node = _Node(list_root, key, value, weak_ref_to_self, real_clock, callbacks)
cache[key] = node
if size_callback:
@@ -268,23 +436,11 @@ class LruCache(Generic[KT, VT]):
if caches.TRACK_MEMORY_USAGE and metrics:
metrics.inc_memory_usage(node.memory)
- def move_node_to_front(node):
- prev_node = node.prev_node
- next_node = node.next_node
- prev_node.next_node = next_node
- next_node.prev_node = prev_node
- prev_node = list_root
- next_node = prev_node.next_node
- node.prev_node = prev_node
- node.next_node = next_node
- prev_node.next_node = node
- next_node.prev_node = node
-
- def delete_node(node):
- prev_node = node.prev_node
- next_node = node.next_node
- prev_node.next_node = next_node
- next_node.prev_node = prev_node
+ def move_node_to_front(node: _Node):
+ node.move_to_front(real_clock, list_root)
+
+ def delete_node(node: _Node) -> int:
+ node.drop_from_lists()
deleted_len = 1
if size_callback:
@@ -411,10 +567,13 @@ class LruCache(Generic[KT, VT]):
@synchronized
def cache_clear() -> None:
- list_root.next_node = list_root
- list_root.prev_node = list_root
for node in cache.values():
node.run_and_clear_callbacks()
+ node.drop_from_lists()
+
+ assert list_root.next_node == list_root
+ assert list_root.prev_node == list_root
+
cache.clear()
if size_callback:
cached_cache_len[0] = 0
@@ -484,3 +643,11 @@ class LruCache(Generic[KT, VT]):
self._on_resize()
return True
return False
+
+ def __del__(self) -> None:
+ # We're about to be deleted, so we make sure to clear up all the nodes
+ # and run callbacks, etc.
+ #
+ # This happens e.g. in the sync code where we have an expiring cache of
+ # lru caches.
+ self.clear()
|