summary refs log tree commit diff
path: root/rust
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-09-06 22:24:46 +0100
committerErik Johnston <erik@matrix.org>2022-09-09 16:22:45 +0100
commit53e83c76b2f153f934086687c74d8bd88380674f (patch)
tree681b7e1d1ec80b9f9f5f387a60d92f8f6f1d318a /rust
parentUse an upsert for `receipts_graph`. (#13752) (diff)
downloadsynapse-53e83c76b2f153f934086687c74d8bd88380674f.tar.xz
SNAPSHOT
Diffstat (limited to 'rust')
-rw-r--r--rust/Cargo.toml3
-rw-r--r--rust/src/lib.rs5
-rw-r--r--rust/src/lru_cache.rs232
3 files changed, 239 insertions, 1 deletions
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 0a9760cafc..394d4e799c 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -18,4 +18,7 @@ crate-type = ["cdylib"]
 name = "synapse.synapse_rust"
 
 [dependencies]
+intrusive-collections = "0.9.4"
+lazy_static = "1.4.0"
+log = "0.4.17"
 pyo3 = { version = "0.16.5", features = ["extension-module", "macros", "abi3", "abi3-py37"] }
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 142fc2ed93..dc01c623a9 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -1,5 +1,7 @@
 use pyo3::prelude::*;
 
+mod lru_cache;
+
 /// Formats the sum of two numbers as string.
 #[pyfunction]
 #[pyo3(text_signature = "(a, b, /)")]
@@ -9,8 +11,9 @@ fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
 
 /// The entry point for defining the Python module.
 #[pymodule]
-fn synapse_rust(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
+fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
 
+    lru_cache::register_module(py, m)?;
     Ok(())
 }
diff --git a/rust/src/lru_cache.rs b/rust/src/lru_cache.rs
new file mode 100644
index 0000000000..ac36e9162d
--- /dev/null
+++ b/rust/src/lru_cache.rs
@@ -0,0 +1,232 @@
+use std::sync::{Arc, Mutex};
+
+use intrusive_collections::{intrusive_adapter, LinkedListAtomicLink};
+use intrusive_collections::{LinkedList, LinkedListLink};
+use lazy_static::lazy_static;
+use log::error;
+use pyo3::prelude::*;
+use pyo3::types::PySet;
+
+/// Called when registering modules with python.
+pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
+    let child_module = PyModule::new(py, "push")?;
+    child_module.add_class::<LruCacheNode>()?;
+    child_module.add_class::<PerCacheLinkedList>()?;
+    child_module.add_function(wrap_pyfunction!(get_global_list, m)?)?;
+
+    m.add_submodule(child_module)?;
+
+    // We need to manually add the module to sys.modules to make `from
+    // synapse.synapse_rust import push` work.
+    py.import("sys")?
+        .getattr("modules")?
+        .set_item("synapse.synapse_rust.lru_cache", child_module)?;
+
+    Ok(())
+}
+
+#[pyclass]
+#[derive(Clone)]
+struct PerCacheLinkedList(Arc<Mutex<LinkedList<LruCacheNodeAdapterPerCache>>>);
+
+#[pymethods]
+impl PerCacheLinkedList {
+    #[new]
+    fn new() -> PerCacheLinkedList {
+        PerCacheLinkedList(Default::default())
+    }
+
+    fn get_back(&self) -> Option<LruCacheNode> {
+        let list = self.0.lock().expect("poisoned");
+        list.back().clone_pointer().map(|n| LruCacheNode(n))
+    }
+}
+
+struct LruCacheNodeInner {
+    per_cache_link: LinkedListAtomicLink,
+    global_list_link: LinkedListAtomicLink,
+    per_cache_list: Arc<Mutex<LinkedList<LruCacheNodeAdapterPerCache>>>,
+    cache: Mutex<Option<PyObject>>,
+    key: PyObject,
+    value: PyObject,
+    callbacks: Py<PySet>,
+    memory: usize,
+}
+
+#[pyclass]
+struct LruCacheNode(Arc<LruCacheNodeInner>);
+
+#[pymethods]
+impl LruCacheNode {
+    #[new]
+    fn py_new(
+        cache: PyObject,
+        cache_list: PerCacheLinkedList,
+        key: PyObject,
+        value: PyObject,
+        callbacks: Py<PySet>,
+        memory: usize,
+    ) -> Self {
+        let node = Arc::new(LruCacheNodeInner {
+            per_cache_link: Default::default(),
+            global_list_link: Default::default(),
+            per_cache_list: cache_list.0,
+            cache: Mutex::new(Some(cache)),
+            key,
+            value,
+            callbacks,
+            memory,
+        });
+
+        GLOBAL_LIST
+            .lock()
+            .expect("posioned")
+            .push_front(node.clone());
+
+        node.per_cache_list
+            .lock()
+            .expect("posioned")
+            .push_front(node.clone());
+
+        LruCacheNode(node)
+    }
+
+    fn add_callbacks(&self, py: Python<'_>, callbacks: Py<PySet>) -> PyResult<()> {
+        let new_callbacks = callbacks.as_ref(py);
+        let current_callbacks = self.0.callbacks.as_ref(py);
+
+        for cb in new_callbacks {
+            current_callbacks.add(cb)?;
+        }
+
+        Ok(())
+    }
+
+    fn run_and_clear_callbacks(&self, py: Python<'_>) {
+        let current_callbacks = self.0.callbacks.as_ref(py);
+
+        if current_callbacks.len() == 0 {
+            return;
+        }
+
+        // Swap out the stored callbacks with an empty list
+        let callbacks = std::mem::replace(&mut *callback_guard, Vec::new());
+
+        // Drop the lock
+        std::mem::drop(callback_guard);
+
+        for callback in callbacks {
+            if let Err(err) = callback.call0(py) {
+                error!("LruCacheNode callback errored: {err}");
+            }
+        }
+    }
+
+    fn drop_from_cache(&self) -> PyResult<()> {
+        let cache = self.0.cache.lock().expect("poisoned").take();
+
+        if let Some(cache) = cache {
+            Python::with_gil(|py| cache.call_method1(py, "pop", (&self.0.key, None::<()>)))?;
+        }
+
+        self.drop_from_lists();
+
+        Ok(())
+    }
+
+    fn drop_from_lists(&self) {
+        if self.0.global_list_link.is_linked() {
+            let mut glboal_list = GLOBAL_LIST.lock().expect("poisoned");
+
+            let mut curor_mut = unsafe {
+                // Getting the cursor is unsafe as we need to ensure the list link
+                // belongs to the given list.
+                glboal_list.cursor_mut_from_ptr(Arc::into_raw(self.0.clone()))
+            };
+
+            curor_mut.remove();
+        }
+
+        if self.0.per_cache_link.is_linked() {
+            let mut per_cache_list = self.0.per_cache_list.lock().expect("poisoned");
+
+            let mut curor_mut = unsafe {
+                // Getting the cursor is unsafe as we need to ensure the list link
+                // belongs to the given list.
+                per_cache_list.cursor_mut_from_ptr(Arc::into_raw(self.0.clone()))
+            };
+
+            curor_mut.remove();
+        }
+    }
+
+    fn move_to_front(&self) {
+        if self.0.global_list_link.is_linked() {
+            let mut global_list = GLOBAL_LIST.lock().expect("poisoned");
+
+            let mut curor_mut = unsafe {
+                // Getting the cursor is unsafe as we need to ensure the list link
+                // belongs to the given list.
+                global_list.cursor_mut_from_ptr(Arc::into_raw(self.0.clone()))
+            };
+            curor_mut.remove();
+
+            global_list.push_front(self.0.clone());
+        }
+
+        if self.0.per_cache_link.is_linked() {
+            let mut per_cache_list = self.0.per_cache_list.lock().expect("poisoned");
+
+            let mut curor_mut = unsafe {
+                // Getting the cursor is unsafe as we need to ensure the list link
+                // belongs to the given list.
+                per_cache_list.cursor_mut_from_ptr(Arc::into_raw(self.0.clone()))
+            };
+
+            curor_mut.remove();
+
+            per_cache_list.push_front(self.0.clone());
+        }
+    }
+
+    #[getter]
+    fn key(&self) -> &PyObject {
+        &self.0.key
+    }
+
+    #[getter]
+    fn value(&self) -> &PyObject {
+        &self.0.value
+    }
+
+    #[getter]
+    fn memory(&self) -> usize {
+        self.0.memory
+    }
+}
+
+#[pyfunction]
+fn get_global_list() -> Vec<LruCacheNode> {
+    let list = GLOBAL_LIST.lock().expect("poisoned");
+
+    let mut vec = Vec::new();
+
+    let mut cursor = list.front();
+
+    while let Some(n) = cursor.clone_pointer() {
+        vec.push(LruCacheNode(n));
+
+        cursor.move_next();
+    }
+
+    vec
+}
+
+intrusive_adapter!(LruCacheNodeAdapterPerCache = Arc<LruCacheNodeInner>: LruCacheNodeInner { per_cache_link: LinkedListLink });
+intrusive_adapter!(LruCacheNodeAdapterGlobal = Arc<LruCacheNodeInner>: LruCacheNodeInner { global_list_link: LinkedListLink });
+
+lazy_static! {
+    static ref GLOBAL_LIST_ADAPTER: LruCacheNodeAdapterGlobal = LruCacheNodeAdapterGlobal::new();
+    static ref GLOBAL_LIST: Arc<Mutex<LinkedList<LruCacheNodeAdapterGlobal>>> =
+        Arc::new(Mutex::new(LinkedList::new(GLOBAL_LIST_ADAPTER.clone())));
+}