summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2023-10-02 07:25:06 -0400
committerPatrick Cloke <patrickc@matrix.org>2023-10-02 07:25:06 -0400
commit09825c277548e8fccad1f517e2e0a46a985c262d (patch)
treed6a6c2e6e978e0276a45dff2223e76fba4f98865
parentMerge remote-tracking branch 'origin/develop' into erikj/rust_lru_cache (diff)
downloadsynapse-clokep/erikj/rust_lru_cache.tar.xz
-rw-r--r--Cargo.lock21
-rw-r--r--rust/src/lru_cache.rs16
-rw-r--r--stubs/synapse/synapse_rust/lru_cache.pyi48
-rw-r--r--synapse/util/caches/lrucache.py173
-rw-r--r--synapse/util/linked_list.py150
5 files changed, 105 insertions, 303 deletions
diff --git a/Cargo.lock b/Cargo.lock
index ea9aa18a5c..30464a41f5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -103,6 +103,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "adab1eaa3408fb7f0c777a73e7465fd5656136fc93b670eb6df3c88c2c1344e3"
 
 [[package]]
+name = "intrusive-collections"
+version = "0.9.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b694dc9f70c3bda874626d2aed13b780f137aab435f4e9814121955cf706122e"
+dependencies = [
+ "memoffset 0.9.0",
+]
+
+[[package]]
 name = "itoa"
 version = "1.0.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -152,6 +161,15 @@ dependencies = [
 ]
 
 [[package]]
+name = "memoffset"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
 name = "once_cell"
 version = "1.15.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -199,7 +217,7 @@ dependencies = [
  "cfg-if",
  "indoc",
  "libc",
- "memoffset",
+ "memoffset 0.6.5",
  "parking_lot",
  "pyo3-build-config",
  "pyo3-ffi",
@@ -402,6 +420,7 @@ dependencies = [
  "anyhow",
  "blake2",
  "hex",
+ "intrusive-collections",
  "lazy_static",
  "log",
  "pyo3",
diff --git a/rust/src/lru_cache.rs b/rust/src/lru_cache.rs
index 847f5d84be..efdcef8622 100644
--- a/rust/src/lru_cache.rs
+++ b/rust/src/lru_cache.rs
@@ -51,6 +51,13 @@ struct LruCacheNodeInner {
     value: Arc<Mutex<PyObject>>,
     callbacks: Py<PySet>,
     memory: usize,
+    last_access_ts_secs: usize,
+}
+
+impl LruCacheNodeInner {
+    fn update_last_access(&mut self, ts_secs: usize) {
+        self.last_access_ts_secs = ts_secs;
+    }
 }
 
 #[pyclass]
@@ -66,6 +73,7 @@ impl LruCacheNode {
         value: PyObject,
         callbacks: Py<PySet>,
         memory: usize,
+        ts_secs: usize,
     ) -> Self {
         let node = Arc::new(LruCacheNodeInner {
             per_cache_link: Default::default(),
@@ -76,6 +84,7 @@ impl LruCacheNode {
             value: Arc::new(Mutex::new(value)),
             callbacks,
             memory,
+            last_access_ts_secs: ts_secs,
         });
 
         GLOBAL_LIST
@@ -159,7 +168,7 @@ impl LruCacheNode {
         }
     }
 
-    fn move_to_front(&self) {
+    fn move_to_front(&self, ts_secs: usize) {
         if self.0.global_list_link.is_linked() {
             let mut global_list = GLOBAL_LIST.lock().expect("poisoned");
 
@@ -171,6 +180,8 @@ impl LruCacheNode {
             curor_mut.remove();
 
             global_list.push_front(self.0.clone());
+
+            // TODO Update self.0.last_access_ts_secs
         }
 
         if self.0.per_cache_link.is_linked() {
@@ -207,6 +218,9 @@ impl LruCacheNode {
     fn memory(&self) -> usize {
         self.0.memory
     }
+
+    #[getter]
+    fn last_access_ts_secs(&self) -> usize { self.0.last_access_ts_secs }
 }
 
 #[pyfunction]
diff --git a/stubs/synapse/synapse_rust/lru_cache.pyi b/stubs/synapse/synapse_rust/lru_cache.pyi
new file mode 100644
index 0000000000..8ca07bc4dd
--- /dev/null
+++ b/stubs/synapse/synapse_rust/lru_cache.pyi
@@ -0,0 +1,48 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Generic, List, Optional, Set, TypeVar, Collection
+
+from synapse.util.caches.lrucache import LruCache
+
+# Key and Value type for the cache
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+class LruCacheNode(Generic[KT, VT]):
+    key: KT
+    value: VT
+    memory: int
+    last_access_ts_secs: int
+
+    def __init__(
+        self,
+        cache: LruCache,
+        cache_list: "PerCacheLinkedList",
+        key: object,
+        value: object,
+        callbacks: Set[Callable[[], None]],
+        memory: int,
+        ts_secs: int,
+    ) -> None: ...
+    def add_callbacks(self, new_callbacks: Collection[Callable[[], None]]) -> None: ...
+    def run_and_clear_callbacks(self) -> None: ...
+    def drop_from_cache(self) -> None: ...
+    def drop_from_lists(self) -> None: ...
+    def move_to_front(self, ts_secs: int) -> None: ...
+
+class PerCacheLinkedList(Generic[KT, VT]):
+    def __init__(self) -> None: ...
+    def get_back(self) -> Optional[LruCacheNode[KT, VT]]: ...
+
+def get_global_list() -> List[LruCacheNode]: ...
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 7d1e405457..2df4b1004d 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -44,7 +44,11 @@ from twisted.internet.interfaces import IReactorTime
 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.metrics.jemalloc import get_jemalloc_stats
-from synapse.synapse_rust.lru_cache import LruCacheNode, PerCacheLinkedList
+from synapse.synapse_rust.lru_cache import (
+    LruCacheNode,
+    PerCacheLinkedList,
+    get_global_list,
+)
 from synapse.util import Clock, caches
 from synapse.util.caches import CacheMetric, EvictionReason, register_cache
 from synapse.util.caches.treecache import (
@@ -52,7 +56,6 @@ from synapse.util.caches.treecache import (
     iterate_tree_cache_entry,
     iterate_tree_cache_items,
 )
-from synapse.util.linked_list import ListNode
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -95,22 +98,10 @@ VT = TypeVar("VT")
 T = TypeVar("T")
 
 
-class _TimedListNode(ListNode[T]):
-    """A `ListNode` that tracks last access time."""
-
-    __slots__ = ["last_access_ts_secs"]
-
-    def update_last_access(self, clock: Clock) -> None:
-        self.last_access_ts_secs = int(clock.time())
-
-
 # Whether to insert new cache entries to the global list. We only add to it if
 # time based eviction is enabled.
 USE_GLOBAL_LIST = False
 
-# A linked list of all cache entries, allowing efficient time based eviction.
-GLOBAL_ROOT = ListNode["_Node"].create_root_node()
-
 
 @wrap_as_background_process("LruCache._expire_old_entries")
 async def _expire_old_entries(
@@ -124,9 +115,12 @@ async def _expire_old_entries(
         target_cache_memory_usage = autotune_config["target_cache_memory_usage"]
         min_cache_ttl = autotune_config["min_cache_ttl"] / 1000
 
+    # A linked list of all cache entries, allowing efficient time based eviction.
+    global_root = get_global_list()
+
     now = int(clock.time())
-    node = GLOBAL_ROOT.prev_node
-    assert node is not None
+    assert len(global_root) > 0
+    node = global_root[0]
 
     i = 0
 
@@ -148,10 +142,7 @@ async def _expire_old_entries(
                 "Unable to read allocated memory, skipping memory-based cache eviction."
             )
 
-    while node is not GLOBAL_ROOT:
-        # Only the root node isn't a `_TimedListNode`.
-        assert isinstance(node, _TimedListNode)
-
+    for node in global_root[1:]:
         # if node has not aged past expiry_seconds and we are not evicting due to memory usage, there's
         # nothing to do here
         if (
@@ -238,125 +229,6 @@ def setup_expire_lru_cache_entries(hs: "HomeServer") -> None:
     )
 
 
-class _Node(Generic[KT, VT]):
-    __slots__ = [
-        "_list_node",
-        "_global_list_node",
-        "_cache",
-        "key",
-        "value",
-        "callbacks",
-        "memory",
-    ]
-
-    def __init__(
-        self,
-        root: "ListNode[_Node]",
-        key: KT,
-        value: VT,
-        cache: "weakref.ReferenceType[LruCache[KT, VT]]",
-        clock: Clock,
-        callbacks: Collection[Callable[[], None]] = (),
-        prune_unread_entries: bool = True,
-    ):
-        self._list_node = ListNode.insert_after(self, root)
-        self._global_list_node: Optional[_TimedListNode] = None
-        if USE_GLOBAL_LIST and prune_unread_entries:
-            self._global_list_node = _TimedListNode.insert_after(self, GLOBAL_ROOT)
-            self._global_list_node.update_last_access(clock)
-
-        # We store a weak reference to the cache object so that this _Node can
-        # remove itself from the cache. If the cache is dropped we ensure we
-        # remove our entries in the lists.
-        self._cache = cache
-
-        self.key = key
-        self.value = value
-
-        # Set of callbacks to run when the node gets deleted. We store as a list
-        # rather than a set to keep memory usage down (and since we expect few
-        # entries per node, the performance of checking for duplication in a
-        # list vs using a set is negligible).
-        #
-        # Note that we store this as an optional list to keep the memory
-        # footprint down. Storing `None` is free as its a singleton, while empty
-        # lists are 56 bytes (and empty sets are 216 bytes, if we did the naive
-        # thing and used sets).
-        self.callbacks: Optional[List[Callable[[], None]]] = None
-
-        self.add_callbacks(callbacks)
-
-        self.memory = 0
-        if caches.TRACK_MEMORY_USAGE:
-            self.memory = (
-                _get_size_of(key)
-                + _get_size_of(value)
-                + _get_size_of(self._list_node, recurse=False)
-                + _get_size_of(self.callbacks, recurse=False)
-                + _get_size_of(self, recurse=False)
-            )
-            self.memory += _get_size_of(self.memory, recurse=False)
-
-            if self._global_list_node:
-                self.memory += _get_size_of(self._global_list_node, recurse=False)
-                self.memory += _get_size_of(self._global_list_node.last_access_ts_secs)
-
-    def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
-        """Add to stored list of callbacks, removing duplicates."""
-
-        if not callbacks:
-            return
-
-        if not self.callbacks:
-            self.callbacks = []
-
-        for callback in callbacks:
-            if callback not in self.callbacks:
-                self.callbacks.append(callback)
-
-    def run_and_clear_callbacks(self) -> None:
-        """Run all callbacks and clear the stored list of callbacks. Used when
-        the node is being deleted.
-        """
-
-        if not self.callbacks:
-            return
-
-        for callback in self.callbacks:
-            callback()
-
-        self.callbacks = None
-
-    def drop_from_cache(self) -> None:
-        """Drop this node from the cache.
-
-        Ensures that the entry gets removed from the cache and that we get
-        removed from all lists.
-        """
-        cache = self._cache()
-        if (
-            cache is None
-            or cache.pop(self.key, _Sentinel.sentinel) is _Sentinel.sentinel
-        ):
-            # `cache.pop` should call `drop_from_lists()`, unless this Node had
-            # already been removed from the cache.
-            self.drop_from_lists()
-
-    def drop_from_lists(self) -> None:
-        """Remove this node from the cache lists."""
-        self._list_node.remove_from_list()
-
-        if self._global_list_node:
-            self._global_list_node.remove_from_list()
-
-    def move_to_front(self, clock: Clock, cache_list_root: ListNode) -> None:
-        """Moves this node to the front of all the lists its in."""
-        self._list_node.move_after(cache_list_root)
-        if self._global_list_node:
-            self._global_list_node.move_after(GLOBAL_ROOT)
-            self._global_list_node.update_last_access(clock)
-
-
 class _Sentinel(Enum):
     # defining a sentinel in this way allows mypy to correctly handle the
     # type of a dictionary lookup.
@@ -418,7 +290,7 @@ class LruCache(Generic[KT, VT]):
         else:
             real_clock = clock
 
-        cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type()
+        cache: "Union[Dict[KT, LruCacheNode[KT, VT]], TreeCache]" = cache_type()
         self.cache = cache  # Used for introspection.
         self.apply_cache_factor_from_config = apply_cache_factor_from_config
 
@@ -450,12 +322,10 @@ class LruCache(Generic[KT, VT]):
         self.metrics = metrics
 
         # We create a single weakref to self here so that we don't need to keep
-        # creating more each time we create a `_Node`.
+        # creating more each time we create a `LruCacheNode`.
         weak_ref_to_self = weakref.ref(self)
 
-        list_root = ListNode[_Node[KT, VT]].create_root_node()
-
-        rust_linked_list = PerCacheLinkedList()
+        rust_linked_list: "PerCacheLinkedList[KT, VT]" = PerCacheLinkedList()
 
         lock = threading.Lock()
 
@@ -497,13 +367,14 @@ class LruCache(Generic[KT, VT]):
         def add_node(
             key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
         ) -> None:
-            node: _Node[KT, VT] = LruCacheNode(
+            node: "LruCacheNode[KT, VT]" = LruCacheNode(
                 self,
                 rust_linked_list,
                 key,
                 value,
                 set(callbacks),
                 0,
+                int(real_clock.time()),
             )
             cache[key] = node
 
@@ -513,10 +384,10 @@ class LruCache(Generic[KT, VT]):
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.inc_memory_usage(node.memory)
 
-        def move_node_to_front(node: _Node[KT, VT]) -> None:
-            node.move_to_front()
+        def move_node_to_front(node: "LruCacheNode[KT, VT]") -> None:
+            node.move_to_front(int(real_clock.time()))
 
-        def delete_node(node: _Node[KT, VT]) -> int:
+        def delete_node(node: "LruCacheNode[KT, VT]") -> int:
             node.drop_from_lists()
 
             deleted_len = 1
@@ -635,7 +506,7 @@ class LruCache(Generic[KT, VT]):
                 if update_metrics and metrics:
                     metrics.inc_hits()
 
-                # We store entries in the `TreeCache` with values of type `_Node`,
+                # We store entries in the `TreeCache` with values of type `LruCacheNode`,
                 # which we need to unwrap.
                 return (
                     (full_key, lru_node.value)
@@ -730,8 +601,8 @@ class LruCache(Generic[KT, VT]):
                 node.run_and_clear_callbacks()
                 node.drop_from_lists()
 
-            assert list_root.next_node == list_root
-            assert list_root.prev_node == list_root
+            # assert list_root.next_node == list_root
+            # assert list_root.prev_node == list_root
 
             cache.clear()
             if size_callback:
diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py
deleted file mode 100644
index 8efbf061aa..0000000000
--- a/synapse/util/linked_list.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""A circular doubly linked list implementation.
-"""
-
-import threading
-from typing import Generic, Optional, Type, TypeVar
-
-P = TypeVar("P")
-LN = TypeVar("LN", bound="ListNode")
-
-
-class ListNode(Generic[P]):
-    """A node in a circular doubly linked list, with an (optional) reference to
-    a cache entry.
-
-    The reference should only be `None` for the root node or if the node has
-    been removed from the list.
-    """
-
-    # A lock to protect mutating the list prev/next pointers.
-    _LOCK = threading.Lock()
-
-    # We don't use attrs here as in py3.6 you can't have `attr.s(slots=True)`
-    # and inherit from `Generic` for some reason
-    __slots__ = [
-        "cache_entry",
-        "prev_node",
-        "next_node",
-    ]
-
-    def __init__(self, cache_entry: Optional[P] = None) -> None:
-        self.cache_entry = cache_entry
-        self.prev_node: Optional[ListNode[P]] = None
-        self.next_node: Optional[ListNode[P]] = None
-
-    @classmethod
-    def create_root_node(cls: Type["ListNode[P]"]) -> "ListNode[P]":
-        """Create a new linked list by creating a "root" node, which is a node
-        that has prev_node/next_node pointing to itself and no associated cache
-        entry.
-        """
-        root = cls()
-        root.prev_node = root
-        root.next_node = root
-        return root
-
-    @classmethod
-    def insert_after(
-        cls: Type[LN],
-        cache_entry: P,
-        node: "ListNode[P]",
-    ) -> LN:
-        """Create a new list node that is placed after the given node.
-
-        Args:
-            cache_entry: The associated cache entry.
-            node: The existing node in the list to insert the new entry after.
-        """
-        new_node = cls(cache_entry)
-        with cls._LOCK:
-            new_node._refs_insert_after(node)
-        return new_node
-
-    def remove_from_list(self) -> None:
-        """Remove this node from the list."""
-        with self._LOCK:
-            self._refs_remove_node_from_list()
-
-        # We drop the reference to the cache entry to break the reference cycle
-        # between the list node and cache entry, allowing the two to be dropped
-        # immediately rather than at the next GC.
-        self.cache_entry = None
-
-    def move_after(self, node: "ListNode[P]") -> None:
-        """Move this node from its current location in the list to after the
-        given node.
-        """
-        with self._LOCK:
-            # We assert that both this node and the target node is still "alive".
-            assert self.prev_node
-            assert self.next_node
-            assert node.prev_node
-            assert node.next_node
-
-            assert self is not node
-
-            # Remove self from the list
-            self._refs_remove_node_from_list()
-
-            # Insert self back into the list, after target node
-            self._refs_insert_after(node)
-
-    def _refs_remove_node_from_list(self) -> None:
-        """Internal method to *just* remove the node from the list, without
-        e.g. clearing out the cache entry.
-        """
-        if self.prev_node is None or self.next_node is None:
-            # We've already been removed from the list.
-            return
-
-        prev_node = self.prev_node
-        next_node = self.next_node
-
-        prev_node.next_node = next_node
-        next_node.prev_node = prev_node
-
-        # We set these to None so that we don't get circular references,
-        # allowing us to be dropped without having to go via the GC.
-        self.prev_node = None
-        self.next_node = None
-
-    def _refs_insert_after(self, node: "ListNode[P]") -> None:
-        """Internal method to insert the node after the given node."""
-
-        # This method should only be called when we're not already in the list.
-        assert self.prev_node is None
-        assert self.next_node is None
-
-        # We expect the given node to be in the list and thus have valid
-        # prev/next refs.
-        assert node.next_node
-        assert node.prev_node
-
-        prev_node = node
-        next_node = node.next_node
-
-        self.prev_node = prev_node
-        self.next_node = next_node
-
-        prev_node.next_node = self
-        next_node.prev_node = self
-
-    def get_cache_entry(self) -> Optional[P]:
-        """Get the cache entry, returns None if this is the root node (i.e.
-        cache_entry is None) or if the entry has been dropped.
-        """
-        return self.cache_entry