diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 0a9760cafc..394d4e799c 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -18,4 +18,7 @@ crate-type = ["cdylib"]
name = "synapse.synapse_rust"
[dependencies]
+intrusive-collections = "0.9.4"
+lazy_static = "1.4.0"
+log = "0.4.17"
pyo3 = { version = "0.16.5", features = ["extension-module", "macros", "abi3", "abi3-py37"] }
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 142fc2ed93..dc01c623a9 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -1,5 +1,7 @@
use pyo3::prelude::*;
+mod lru_cache;
+
/// Formats the sum of two numbers as string.
#[pyfunction]
#[pyo3(text_signature = "(a, b, /)")]
@@ -9,8 +11,9 @@ fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
/// The entry point for defining the Python module.
#[pymodule]
-fn synapse_rust(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
+fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
+ lru_cache::register_module(py, m)?;
Ok(())
}
diff --git a/rust/src/lru_cache.rs b/rust/src/lru_cache.rs
new file mode 100644
index 0000000000..ac36e9162d
--- /dev/null
+++ b/rust/src/lru_cache.rs
@@ -0,0 +1,232 @@
+use std::sync::{Arc, Mutex};
+
+use intrusive_collections::{intrusive_adapter, LinkedListAtomicLink};
+use intrusive_collections::{LinkedList, LinkedListLink};
+use lazy_static::lazy_static;
+use log::error;
+use pyo3::prelude::*;
+use pyo3::types::PySet;
+
+/// Called when registering modules with python.
+pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
+ let child_module = PyModule::new(py, "push")?;
+ child_module.add_class::<LruCacheNode>()?;
+ child_module.add_class::<PerCacheLinkedList>()?;
+ child_module.add_function(wrap_pyfunction!(get_global_list, m)?)?;
+
+ m.add_submodule(child_module)?;
+
+ // We need to manually add the module to sys.modules to make `from
+ // synapse.synapse_rust import push` work.
+ py.import("sys")?
+ .getattr("modules")?
+ .set_item("synapse.synapse_rust.lru_cache", child_module)?;
+
+ Ok(())
+}
+
+#[pyclass]
+#[derive(Clone)]
+struct PerCacheLinkedList(Arc<Mutex<LinkedList<LruCacheNodeAdapterPerCache>>>);
+
+#[pymethods]
+impl PerCacheLinkedList {
+ #[new]
+ fn new() -> PerCacheLinkedList {
+ PerCacheLinkedList(Default::default())
+ }
+
+ fn get_back(&self) -> Option<LruCacheNode> {
+ let list = self.0.lock().expect("poisoned");
+ list.back().clone_pointer().map(|n| LruCacheNode(n))
+ }
+}
+
+struct LruCacheNodeInner {
+ per_cache_link: LinkedListAtomicLink,
+ global_list_link: LinkedListAtomicLink,
+ per_cache_list: Arc<Mutex<LinkedList<LruCacheNodeAdapterPerCache>>>,
+ cache: Mutex<Option<PyObject>>,
+ key: PyObject,
+ value: PyObject,
+ callbacks: Py<PySet>,
+ memory: usize,
+}
+
+#[pyclass]
+struct LruCacheNode(Arc<LruCacheNodeInner>);
+
+#[pymethods]
+impl LruCacheNode {
+ #[new]
+ fn py_new(
+ cache: PyObject,
+ cache_list: PerCacheLinkedList,
+ key: PyObject,
+ value: PyObject,
+ callbacks: Py<PySet>,
+ memory: usize,
+ ) -> Self {
+ let node = Arc::new(LruCacheNodeInner {
+ per_cache_link: Default::default(),
+ global_list_link: Default::default(),
+ per_cache_list: cache_list.0,
+ cache: Mutex::new(Some(cache)),
+ key,
+ value,
+ callbacks,
+ memory,
+ });
+
+ GLOBAL_LIST
+ .lock()
+ .expect("posioned")
+ .push_front(node.clone());
+
+ node.per_cache_list
+ .lock()
+ .expect("posioned")
+ .push_front(node.clone());
+
+ LruCacheNode(node)
+ }
+
+ fn add_callbacks(&self, py: Python<'_>, callbacks: Py<PySet>) -> PyResult<()> {
+ let new_callbacks = callbacks.as_ref(py);
+ let current_callbacks = self.0.callbacks.as_ref(py);
+
+ for cb in new_callbacks {
+ current_callbacks.add(cb)?;
+ }
+
+ Ok(())
+ }
+
+ fn run_and_clear_callbacks(&self, py: Python<'_>) {
+ let current_callbacks = self.0.callbacks.as_ref(py);
+
+ if current_callbacks.len() == 0 {
+ return;
+ }
+
+ // Swap out the stored callbacks with an empty list
+ let callbacks = std::mem::replace(&mut *callback_guard, Vec::new());
+
+ // Drop the lock
+ std::mem::drop(callback_guard);
+
+ for callback in callbacks {
+ if let Err(err) = callback.call0(py) {
+ error!("LruCacheNode callback errored: {err}");
+ }
+ }
+ }
+
+ fn drop_from_cache(&self) -> PyResult<()> {
+ let cache = self.0.cache.lock().expect("poisoned").take();
+
+ if let Some(cache) = cache {
+ Python::with_gil(|py| cache.call_method1(py, "pop", (&self.0.key, None::<()>)))?;
+ }
+
+ self.drop_from_lists();
+
+ Ok(())
+ }
+
+ fn drop_from_lists(&self) {
+ if self.0.global_list_link.is_linked() {
+ let mut glboal_list = GLOBAL_LIST.lock().expect("poisoned");
+
+ let mut curor_mut = unsafe {
+ // Getting the cursor is unsafe as we need to ensure the list link
+ // belongs to the given list.
+ glboal_list.cursor_mut_from_ptr(Arc::into_raw(self.0.clone()))
+ };
+
+ curor_mut.remove();
+ }
+
+ if self.0.per_cache_link.is_linked() {
+ let mut per_cache_list = self.0.per_cache_list.lock().expect("poisoned");
+
+ let mut curor_mut = unsafe {
+ // Getting the cursor is unsafe as we need to ensure the list link
+ // belongs to the given list.
+ per_cache_list.cursor_mut_from_ptr(Arc::into_raw(self.0.clone()))
+ };
+
+ curor_mut.remove();
+ }
+ }
+
+ fn move_to_front(&self) {
+ if self.0.global_list_link.is_linked() {
+ let mut global_list = GLOBAL_LIST.lock().expect("poisoned");
+
+ let mut curor_mut = unsafe {
+ // Getting the cursor is unsafe as we need to ensure the list link
+ // belongs to the given list.
+ global_list.cursor_mut_from_ptr(Arc::into_raw(self.0.clone()))
+ };
+ curor_mut.remove();
+
+ global_list.push_front(self.0.clone());
+ }
+
+ if self.0.per_cache_link.is_linked() {
+ let mut per_cache_list = self.0.per_cache_list.lock().expect("poisoned");
+
+ let mut curor_mut = unsafe {
+ // Getting the cursor is unsafe as we need to ensure the list link
+ // belongs to the given list.
+ per_cache_list.cursor_mut_from_ptr(Arc::into_raw(self.0.clone()))
+ };
+
+ curor_mut.remove();
+
+ per_cache_list.push_front(self.0.clone());
+ }
+ }
+
+ #[getter]
+ fn key(&self) -> &PyObject {
+ &self.0.key
+ }
+
+ #[getter]
+ fn value(&self) -> &PyObject {
+ &self.0.value
+ }
+
+ #[getter]
+ fn memory(&self) -> usize {
+ self.0.memory
+ }
+}
+
+#[pyfunction]
+fn get_global_list() -> Vec<LruCacheNode> {
+ let list = GLOBAL_LIST.lock().expect("poisoned");
+
+ let mut vec = Vec::new();
+
+ let mut cursor = list.front();
+
+ while let Some(n) = cursor.clone_pointer() {
+ vec.push(LruCacheNode(n));
+
+ cursor.move_next();
+ }
+
+ vec
+}
+
+intrusive_adapter!(LruCacheNodeAdapterPerCache = Arc<LruCacheNodeInner>: LruCacheNodeInner { per_cache_link: LinkedListLink });
+intrusive_adapter!(LruCacheNodeAdapterGlobal = Arc<LruCacheNodeInner>: LruCacheNodeInner { global_list_link: LinkedListLink });
+
+lazy_static! {
+ static ref GLOBAL_LIST_ADAPTER: LruCacheNodeAdapterGlobal = LruCacheNodeAdapterGlobal::new();
+ static ref GLOBAL_LIST: Arc<Mutex<LinkedList<LruCacheNodeAdapterGlobal>>> =
+ Arc::new(Mutex::new(LinkedList::new(GLOBAL_LIST_ADAPTER.clone())));
+}
|