diff options
author | Erik Johnston <erik@matrix.org> | 2022-09-06 22:24:46 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2022-09-09 16:22:45 +0100 |
commit | 53e83c76b2f153f934086687c74d8bd88380674f (patch) | |
tree | 681b7e1d1ec80b9f9f5f387a60d92f8f6f1d318a /rust | |
parent | Use an upsert for `receipts_graph`. (#13752) (diff) | |
download | synapse-53e83c76b2f153f934086687c74d8bd88380674f.tar.xz |
SNAPSHOT
Diffstat (limited to 'rust')
-rw-r--r-- | rust/Cargo.toml | 3 | ||||
-rw-r--r-- | rust/src/lib.rs | 5 | ||||
-rw-r--r-- | rust/src/lru_cache.rs | 232 |
3 files changed, 239 insertions, 1 deletions
diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 0a9760cafc..394d4e799c 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -18,4 +18,7 @@ crate-type = ["cdylib"] name = "synapse.synapse_rust" [dependencies] +intrusive-collections = "0.9.4" +lazy_static = "1.4.0" +log = "0.4.17" pyo3 = { version = "0.16.5", features = ["extension-module", "macros", "abi3", "abi3-py37"] } diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 142fc2ed93..dc01c623a9 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,5 +1,7 @@ use pyo3::prelude::*; +mod lru_cache; + /// Formats the sum of two numbers as string. #[pyfunction] #[pyo3(text_signature = "(a, b, /)")] @@ -9,8 +11,9 @@ fn sum_as_string(a: usize, b: usize) -> PyResult<String> { /// The entry point for defining the Python module. #[pymodule] -fn synapse_rust(_py: Python<'_>, m: &PyModule) -> PyResult<()> { +fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; + lru_cache::register_module(py, m)?; Ok(()) } diff --git a/rust/src/lru_cache.rs b/rust/src/lru_cache.rs new file mode 100644 index 0000000000..ac36e9162d --- /dev/null +++ b/rust/src/lru_cache.rs @@ -0,0 +1,232 @@ +use std::sync::{Arc, Mutex}; + +use intrusive_collections::{intrusive_adapter, LinkedListAtomicLink}; +use intrusive_collections::{LinkedList, LinkedListLink}; +use lazy_static::lazy_static; +use log::error; +use pyo3::prelude::*; +use pyo3::types::PySet; + +/// Called when registering modules with python. +pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { + let child_module = PyModule::new(py, "push")?; + child_module.add_class::<LruCacheNode>()?; + child_module.add_class::<PerCacheLinkedList>()?; + child_module.add_function(wrap_pyfunction!(get_global_list, m)?)?; + + m.add_submodule(child_module)?; + + // We need to manually add the module to sys.modules to make `from + // synapse.synapse_rust import push` work. + py.import("sys")? + .getattr("modules")? + .set_item("synapse.synapse_rust.lru_cache", child_module)?; + + Ok(()) +} + +#[pyclass] +#[derive(Clone)] +struct PerCacheLinkedList(Arc<Mutex<LinkedList<LruCacheNodeAdapterPerCache>>>); + +#[pymethods] +impl PerCacheLinkedList { + #[new] + fn new() -> PerCacheLinkedList { + PerCacheLinkedList(Default::default()) + } + + fn get_back(&self) -> Option<LruCacheNode> { + let list = self.0.lock().expect("poisoned"); + list.back().clone_pointer().map(|n| LruCacheNode(n)) + } +} + +struct LruCacheNodeInner { + per_cache_link: LinkedListAtomicLink, + global_list_link: LinkedListAtomicLink, + per_cache_list: Arc<Mutex<LinkedList<LruCacheNodeAdapterPerCache>>>, + cache: Mutex<Option<PyObject>>, + key: PyObject, + value: PyObject, + callbacks: Py<PySet>, + memory: usize, +} + +#[pyclass] +struct LruCacheNode(Arc<LruCacheNodeInner>); + +#[pymethods] +impl LruCacheNode { + #[new] + fn py_new( + cache: PyObject, + cache_list: PerCacheLinkedList, + key: PyObject, + value: PyObject, + callbacks: Py<PySet>, + memory: usize, + ) -> Self { + let node = Arc::new(LruCacheNodeInner { + per_cache_link: Default::default(), + global_list_link: Default::default(), + per_cache_list: cache_list.0, + cache: Mutex::new(Some(cache)), + key, + value, + callbacks, + memory, + }); + + GLOBAL_LIST + .lock() + .expect("posioned") + .push_front(node.clone()); + + node.per_cache_list + .lock() + .expect("posioned") + .push_front(node.clone()); + + LruCacheNode(node) + } + + fn add_callbacks(&self, py: Python<'_>, callbacks: Py<PySet>) -> PyResult<()> { + let new_callbacks = callbacks.as_ref(py); + let current_callbacks = self.0.callbacks.as_ref(py); + + for cb in new_callbacks { + current_callbacks.add(cb)?; + } + + Ok(()) + } + + fn run_and_clear_callbacks(&self, py: Python<'_>) { + let current_callbacks = self.0.callbacks.as_ref(py); + + if current_callbacks.len() == 0 { + return; + } + + // Swap out the stored callbacks with an empty list + let callbacks = std::mem::replace(&mut *callback_guard, Vec::new()); + + // Drop the lock + std::mem::drop(callback_guard); + + for callback in callbacks { + if let Err(err) = callback.call0(py) { + error!("LruCacheNode callback errored: {err}"); + } + } + } + + fn drop_from_cache(&self) -> PyResult<()> { + let cache = self.0.cache.lock().expect("poisoned").take(); + + if let Some(cache) = cache { + Python::with_gil(|py| cache.call_method1(py, "pop", (&self.0.key, None::<()>)))?; + } + + self.drop_from_lists(); + + Ok(()) + } + + fn drop_from_lists(&self) { + if self.0.global_list_link.is_linked() { + let mut glboal_list = GLOBAL_LIST.lock().expect("poisoned"); + + let mut curor_mut = unsafe { + // Getting the cursor is unsafe as we need to ensure the list link + // belongs to the given list. + glboal_list.cursor_mut_from_ptr(Arc::into_raw(self.0.clone())) + }; + + curor_mut.remove(); + } + + if self.0.per_cache_link.is_linked() { + let mut per_cache_list = self.0.per_cache_list.lock().expect("poisoned"); + + let mut curor_mut = unsafe { + // Getting the cursor is unsafe as we need to ensure the list link + // belongs to the given list. + per_cache_list.cursor_mut_from_ptr(Arc::into_raw(self.0.clone())) + }; + + curor_mut.remove(); + } + } + + fn move_to_front(&self) { + if self.0.global_list_link.is_linked() { + let mut global_list = GLOBAL_LIST.lock().expect("poisoned"); + + let mut curor_mut = unsafe { + // Getting the cursor is unsafe as we need to ensure the list link + // belongs to the given list. + global_list.cursor_mut_from_ptr(Arc::into_raw(self.0.clone())) + }; + curor_mut.remove(); + + global_list.push_front(self.0.clone()); + } + + if self.0.per_cache_link.is_linked() { + let mut per_cache_list = self.0.per_cache_list.lock().expect("poisoned"); + + let mut curor_mut = unsafe { + // Getting the cursor is unsafe as we need to ensure the list link + // belongs to the given list. + per_cache_list.cursor_mut_from_ptr(Arc::into_raw(self.0.clone())) + }; + + curor_mut.remove(); + + per_cache_list.push_front(self.0.clone()); + } + } + + #[getter] + fn key(&self) -> &PyObject { + &self.0.key + } + + #[getter] + fn value(&self) -> &PyObject { + &self.0.value + } + + #[getter] + fn memory(&self) -> usize { + self.0.memory + } +} + +#[pyfunction] +fn get_global_list() -> Vec<LruCacheNode> { + let list = GLOBAL_LIST.lock().expect("poisoned"); + + let mut vec = Vec::new(); + + let mut cursor = list.front(); + + while let Some(n) = cursor.clone_pointer() { + vec.push(LruCacheNode(n)); + + cursor.move_next(); + } + + vec +} + +intrusive_adapter!(LruCacheNodeAdapterPerCache = Arc<LruCacheNodeInner>: LruCacheNodeInner { per_cache_link: LinkedListLink }); +intrusive_adapter!(LruCacheNodeAdapterGlobal = Arc<LruCacheNodeInner>: LruCacheNodeInner { global_list_link: LinkedListLink }); + +lazy_static! { + static ref GLOBAL_LIST_ADAPTER: LruCacheNodeAdapterGlobal = LruCacheNodeAdapterGlobal::new(); + static ref GLOBAL_LIST: Arc<Mutex<LinkedList<LruCacheNodeAdapterGlobal>>> = + Arc::new(Mutex::new(LinkedList::new(GLOBAL_LIST_ADAPTER.clone()))); +} |