From 18ac015ecdd66fde7275efaaf07995b35ad4c41b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 5 Dec 2022 14:03:28 +0000 Subject: bindings --- rust/src/lib.rs | 1 + rust/src/tree_cache.rs | 344 --------------------------------------- rust/src/tree_cache/binding.rs | 128 +++++++++++++++ rust/src/tree_cache/mod.rs | 360 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 489 insertions(+), 344 deletions(-) delete mode 100644 rust/src/tree_cache.rs create mode 100644 rust/src/tree_cache/binding.rs create mode 100644 rust/src/tree_cache/mod.rs (limited to 'rust/src') diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 00f72dc59f..6db2b1eae2 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -27,6 +27,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?; push::register_module(py, m)?; + tree_cache::binding::register_module(py, m)?; Ok(()) } diff --git a/rust/src/tree_cache.rs b/rust/src/tree_cache.rs deleted file mode 100644 index 6796229d64..0000000000 --- a/rust/src/tree_cache.rs +++ /dev/null @@ -1,344 +0,0 @@ -use std::{collections::HashMap, hash::Hash}; - -use anyhow::{bail, Error}; - -pub enum TreeCacheNode { - Leaf(V), - Branch(usize, HashMap>), -} - -impl TreeCacheNode { - pub fn new_branch() -> Self { - TreeCacheNode::Branch(0, Default::default()) - } - - fn len(&self) -> usize { - match self { - TreeCacheNode::Leaf(_) => 1, - TreeCacheNode::Branch(size, _) => *size, - } - } -} - -impl<'a, K: Eq + Hash + 'a, V> TreeCacheNode { - pub fn set( - &mut self, - mut key: impl Iterator, - value: V, - ) -> Result<(usize, usize), Error> { - if let Some(k) = key.next() { - match self { - TreeCacheNode::Leaf(_) => bail!("Given key is too long"), - TreeCacheNode::Branch(size, map) => { - let node = map.entry(k).or_insert_with(TreeCacheNode::new_branch); - let (added, removed) = node.set(key, value)?; - - *size += added; - *size -= removed; - - Ok((added, removed)) - } - } - } else { - let added = if let TreeCacheNode::Branch(_, map) = self { - (1, map.len()) - } else { - (0, 0) - }; - - *self = TreeCacheNode::Leaf(value); - - Ok(added) - } - } - - pub fn pop( - &mut self, - current_key: &K, - mut next_keys: impl Iterator, - ) -> Result>, Error> { - if let Some(next_key) = next_keys.next() { - match self { - TreeCacheNode::Leaf(_) => bail!("Given key is too long"), - TreeCacheNode::Branch(size, map) => { - let node = if let Some(node) = map.get_mut(current_key) { - node - } else { - return Ok(None); - }; - - if let Some(popped) = node.pop(next_key, next_keys)? { - *size -= node.len(); - - Ok(Some(popped)) - } else { - Ok(None) - } - } - } - } else { - match self { - TreeCacheNode::Leaf(_) => bail!("Given key is too long"), - TreeCacheNode::Branch(size, map) => { - if let Some(node) = map.remove(current_key) { - *size -= node.len(); - - Ok(Some(node)) - } else { - Ok(None) - } - } - } - } - } - - pub fn items(&self) -> impl Iterator, &V)> { - let mut stack = vec![(vec![], self)]; - - std::iter::from_fn(move || { - while let Some((prefix, node)) = stack.pop() { - match node { - TreeCacheNode::Leaf(value) => return Some((prefix, value)), - TreeCacheNode::Branch(_, map) => { - stack.extend(map.iter().map(|(k, v)| { - let mut prefix = prefix.clone(); - prefix.push(k); - (prefix, v) - })); - } - } - } - - None - }) - } -} - -pub struct TreeCache { - root: TreeCacheNode, -} - -impl<'a, K: Eq + Hash + 'a, V> TreeCache { - pub fn new() -> Self { - TreeCache { - root: TreeCacheNode::new_branch(), - } - } - - pub fn set(&mut self, key: impl IntoIterator, value: V) -> Result<(), Error> { - self.root.set(key.into_iter(), value)?; - - Ok(()) - } - - pub fn get_node( - &self, - key: impl IntoIterator, - ) -> Result>, Error> { - let mut node = &self.root; - - for k in key { - match node { - TreeCacheNode::Leaf(_) => bail!("Given key is too long"), - TreeCacheNode::Branch(_, map) => { - node = if let Some(node) = map.get(k) { - node - } else { - return Ok(None); - }; - } - } - } - - Ok(Some(node)) - } - - pub fn get(&self, key: impl IntoIterator) -> Result, Error> { - if let Some(node) = self.get_node(key)? { - match node { - TreeCacheNode::Leaf(value) => Ok(Some(value)), - TreeCacheNode::Branch(_, _) => bail!("Given key is too short"), - } - } else { - Ok(None) - } - } - - pub fn pop_node( - &mut self, - key: impl IntoIterator, - ) -> Result>, Error> { - let mut key_iter = key.into_iter(); - - let k = if let Some(k) = key_iter.next() { - k - } else { - let node = std::mem::replace(&mut self.root, TreeCacheNode::new_branch()); - return Ok(Some(node)); - }; - - self.root.pop(k, key_iter) - } - - pub fn pop(&mut self, key: &[K]) -> Result, Error> { - if let Some(node) = self.pop_node(key)? { - match node { - TreeCacheNode::Leaf(value) => Ok(Some(value)), - TreeCacheNode::Branch(_, _) => bail!("Given key is too short"), - } - } else { - Ok(None) - } - } - - pub fn clear(&mut self) { - self.root = TreeCacheNode::new_branch(); - } - - pub fn len(&self) -> usize { - match self.root { - TreeCacheNode::Leaf(_) => 1, - TreeCacheNode::Branch(size, _) => size, - } - } - - pub fn values(&self) -> impl Iterator { - let mut stack = vec![&self.root]; - - std::iter::from_fn(move || { - while let Some(node) = stack.pop() { - match node { - TreeCacheNode::Leaf(value) => return Some(value), - TreeCacheNode::Branch(_, map) => { - stack.extend(map.values()); - } - } - } - - None - }) - } - - pub fn items(&self) -> impl Iterator, &V)> { - self.root.items() - } -} - -#[cfg(test)] -mod test { - use std::collections::BTreeSet; - - use super::*; - - #[test] - fn get_set() -> Result<(), Error> { - let mut cache = TreeCache::new(); - - cache.set(vec!["a", "b"], "c")?; - - assert_eq!(cache.get(&["a", "b"])?, Some(&"c")); - - let node = cache.get_node(&["a"])?.unwrap(); - - match node { - TreeCacheNode::Leaf(_) => bail!("expected branch"), - TreeCacheNode::Branch(_, map) => { - assert_eq!(map.len(), 1); - assert!(map.contains_key("b")); - } - } - - Ok(()) - } - - #[test] - fn length() -> Result<(), Error> { - let mut cache = TreeCache::new(); - - cache.set(vec!["a", "b"], "c")?; - - assert_eq!(cache.len(), 1); - - cache.set(vec!["a", "b"], "d")?; - - assert_eq!(cache.len(), 1); - - cache.set(vec!["e", "f"], "g")?; - - assert_eq!(cache.len(), 2); - - cache.set(vec!["e", "h"], "i")?; - - assert_eq!(cache.len(), 3); - - cache.set(vec!["e"], "i")?; - - assert_eq!(cache.len(), 2); - - cache.pop_node(&["a"])?; - - assert_eq!(cache.len(), 1); - - Ok(()) - } - - #[test] - fn clear() -> Result<(), Error> { - let mut cache = TreeCache::new(); - - cache.set(vec!["a", "b"], "c")?; - - assert_eq!(cache.len(), 1); - - cache.clear(); - - assert_eq!(cache.len(), 0); - - assert_eq!(cache.get(&["a", "b"])?, None); - - Ok(()) - } - - #[test] - fn pop() -> Result<(), Error> { - let mut cache = TreeCache::new(); - - cache.set(vec!["a", "b"], "c")?; - assert_eq!(cache.pop(&["a", "b"])?, Some("c")); - assert_eq!(cache.pop(&["a", "b"])?, None); - - Ok(()) - } - - #[test] - fn values() -> Result<(), Error> { - let mut cache = TreeCache::new(); - - cache.set(vec!["a", "b"], "c")?; - - let expected = ["c"].iter().collect(); - assert_eq!(cache.values().collect::>(), expected); - - cache.set(vec!["d", "e"], "f")?; - - let expected = ["c", "f"].iter().collect(); - assert_eq!(cache.values().collect::>(), expected); - - Ok(()) - } - - #[test] - fn items() -> Result<(), Error> { - let mut cache = TreeCache::new(); - - cache.set(vec!["a", "b"], "c")?; - cache.set(vec!["d", "e"], "f")?; - - let expected = [(vec![&"a", &"b"], &"c"), (vec![&"d", &"e"], &"f")] - .into_iter() - .collect(); - assert_eq!(cache.items().collect::>(), expected); - - Ok(()) - } -} diff --git a/rust/src/tree_cache/binding.rs b/rust/src/tree_cache/binding.rs new file mode 100644 index 0000000000..70207f8781 --- /dev/null +++ b/rust/src/tree_cache/binding.rs @@ -0,0 +1,128 @@ +use std::hash::Hash; + +use anyhow::Error; +use pyo3::{ + pyclass, pymethods, types::PyModule, IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject, +}; + +use super::TreeCache; + +pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { + let child_module = PyModule::new(py, "tree_cache")?; + child_module.add_class::()?; + + m.add_submodule(child_module)?; + + // We need to manually add the module to sys.modules to make `from + // synapse.synapse_rust import push` work. + py.import("sys")? + .getattr("modules")? + .set_item("synapse.synapse_rust.tree_cache", child_module)?; + + Ok(()) +} + +struct HashablePyObject { + obj: PyObject, + hash: isize, +} + +impl HashablePyObject { + pub fn new(obj: &PyAny) -> Result { + let hash = obj.hash()?; + + Ok(HashablePyObject { + obj: obj.to_object(obj.py()), + hash, + }) + } +} + +impl IntoPy for &HashablePyObject { + fn into_py(self, _: Python<'_>) -> PyObject { + self.obj.clone() + } +} + +impl Hash for HashablePyObject { + fn hash(&self, state: &mut H) { + self.hash.hash(state); + } +} + +impl PartialEq for HashablePyObject { + fn eq(&self, other: &Self) -> bool { + let equal = Python::with_gil(|py| { + let result = self.obj.as_ref(py).eq(other.obj.as_ref(py)); + result.unwrap_or(false) + }); + + equal + } +} + +impl Eq for HashablePyObject {} + +#[pyclass] +struct PythonTreeCache(TreeCache); + +#[pymethods] +impl PythonTreeCache { + #[new] + fn new() -> Self { + PythonTreeCache(Default::default()) + } + + pub fn set(&mut self, key: &PyAny, value: PyObject) -> Result<(), Error> { + let v: Vec = key + .iter()? + .map(|obj| HashablePyObject::new(obj?)) + .collect::>()?; + + self.0.set(v, value)?; + + Ok(()) + } + + // pub fn get_node(&self, key: &PyAny) -> Result>, Error> { + // todo!() + // } + + pub fn get(&self, key: &PyAny) -> Result, Error> { + let v: Vec = key + .iter()? + .map(|obj| HashablePyObject::new(obj?)) + .collect::>()?; + + Ok(self.0.get(&v)?) + } + + // pub fn pop_node(&mut self, key: &PyAny) -> Result>, Error> { + // todo!() + // } + + pub fn pop(&mut self, key: &PyAny) -> Result, Error> { + let v: Vec = key + .iter()? + .map(|obj| HashablePyObject::new(obj?)) + .collect::>()?; + + Ok(self.0.pop(&v)?) + } + + pub fn clear(&mut self) { + self.0.clear() + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn values(&self) -> Vec<&PyObject> { + self.0.values().collect() + } + + pub fn items(&self) -> Vec<(Vec<&HashablePyObject>, &PyObject)> { + todo!() + } +} diff --git a/rust/src/tree_cache/mod.rs b/rust/src/tree_cache/mod.rs new file mode 100644 index 0000000000..0a4905b881 --- /dev/null +++ b/rust/src/tree_cache/mod.rs @@ -0,0 +1,360 @@ +use std::{collections::HashMap, hash::Hash}; + +use anyhow::{bail, Error}; + +pub mod binding; + +pub enum TreeCacheNode { + Leaf(V), + Branch(usize, HashMap>), +} + +impl TreeCacheNode { + pub fn new_branch() -> Self { + TreeCacheNode::Branch(0, Default::default()) + } + + fn len(&self) -> usize { + match self { + TreeCacheNode::Leaf(_) => 1, + TreeCacheNode::Branch(size, _) => *size, + } + } +} + +impl<'a, K: Eq + Hash + 'a, V> TreeCacheNode { + pub fn set( + &mut self, + mut key: impl Iterator, + value: V, + ) -> Result<(usize, usize), Error> { + if let Some(k) = key.next() { + match self { + TreeCacheNode::Leaf(_) => bail!("Given key is too long"), + TreeCacheNode::Branch(size, map) => { + let node = map.entry(k).or_insert_with(TreeCacheNode::new_branch); + let (added, removed) = node.set(key, value)?; + + *size += added; + *size -= removed; + + Ok((added, removed)) + } + } + } else { + let added = if let TreeCacheNode::Branch(_, map) = self { + (1, map.len()) + } else { + (0, 0) + }; + + *self = TreeCacheNode::Leaf(value); + + Ok(added) + } + } + + pub fn pop( + &mut self, + current_key: &K, + mut next_keys: impl Iterator, + ) -> Result>, Error> { + if let Some(next_key) = next_keys.next() { + match self { + TreeCacheNode::Leaf(_) => bail!("Given key is too long"), + TreeCacheNode::Branch(size, map) => { + let node = if let Some(node) = map.get_mut(current_key) { + node + } else { + return Ok(None); + }; + + if let Some(popped) = node.pop(next_key, next_keys)? { + *size -= node.len(); + + Ok(Some(popped)) + } else { + Ok(None) + } + } + } + } else { + match self { + TreeCacheNode::Leaf(_) => bail!("Given key is too long"), + TreeCacheNode::Branch(size, map) => { + if let Some(node) = map.remove(current_key) { + *size -= node.len(); + + Ok(Some(node)) + } else { + Ok(None) + } + } + } + } + } + + pub fn items(&self) -> impl Iterator, &V)> { + let mut stack = vec![(vec![], self)]; + + std::iter::from_fn(move || { + while let Some((prefix, node)) = stack.pop() { + match node { + TreeCacheNode::Leaf(value) => return Some((prefix, value)), + TreeCacheNode::Branch(_, map) => { + stack.extend(map.iter().map(|(k, v)| { + let mut prefix = prefix.clone(); + prefix.push(k); + (prefix, v) + })); + } + } + } + + None + }) + } +} + +impl Default for TreeCacheNode { + fn default() -> Self { + TreeCacheNode::new_branch() + } +} + +pub struct TreeCache { + root: TreeCacheNode, +} + +impl TreeCache { + pub fn new() -> Self { + TreeCache { + root: TreeCacheNode::new_branch(), + } + } +} + +impl<'a, K: Eq + Hash + 'a, V> TreeCache { + pub fn set(&mut self, key: impl IntoIterator, value: V) -> Result<(), Error> { + self.root.set(key.into_iter(), value)?; + + Ok(()) + } + + pub fn get_node( + &self, + key: impl IntoIterator, + ) -> Result>, Error> { + let mut node = &self.root; + + for k in key { + match node { + TreeCacheNode::Leaf(_) => bail!("Given key is too long"), + TreeCacheNode::Branch(_, map) => { + node = if let Some(node) = map.get(k) { + node + } else { + return Ok(None); + }; + } + } + } + + Ok(Some(node)) + } + + pub fn get(&self, key: impl IntoIterator) -> Result, Error> { + if let Some(node) = self.get_node(key)? { + match node { + TreeCacheNode::Leaf(value) => Ok(Some(value)), + TreeCacheNode::Branch(_, _) => bail!("Given key is too short"), + } + } else { + Ok(None) + } + } + + pub fn pop_node( + &mut self, + key: impl IntoIterator, + ) -> Result>, Error> { + let mut key_iter = key.into_iter(); + + let k = if let Some(k) = key_iter.next() { + k + } else { + let node = std::mem::replace(&mut self.root, TreeCacheNode::new_branch()); + return Ok(Some(node)); + }; + + self.root.pop(k, key_iter) + } + + pub fn pop(&mut self, key: &[K]) -> Result, Error> { + if let Some(node) = self.pop_node(key)? { + match node { + TreeCacheNode::Leaf(value) => Ok(Some(value)), + TreeCacheNode::Branch(_, _) => bail!("Given key is too short"), + } + } else { + Ok(None) + } + } + + pub fn clear(&mut self) { + self.root = TreeCacheNode::new_branch(); + } + + pub fn len(&self) -> usize { + match self.root { + TreeCacheNode::Leaf(_) => 1, + TreeCacheNode::Branch(size, _) => size, + } + } + + pub fn values(&self) -> impl Iterator { + let mut stack = vec![&self.root]; + + std::iter::from_fn(move || { + while let Some(node) = stack.pop() { + match node { + TreeCacheNode::Leaf(value) => return Some(value), + TreeCacheNode::Branch(_, map) => { + stack.extend(map.values()); + } + } + } + + None + }) + } + + pub fn items(&self) -> impl Iterator, &V)> { + self.root.items() + } +} + +impl Default for TreeCache { + fn default() -> Self { + TreeCache::new() + } +} + +#[cfg(test)] +mod test { + use std::collections::BTreeSet; + + use super::*; + + #[test] + fn get_set() -> Result<(), Error> { + let mut cache = TreeCache::new(); + + cache.set(vec!["a", "b"], "c")?; + + assert_eq!(cache.get(&["a", "b"])?, Some(&"c")); + + let node = cache.get_node(&["a"])?.unwrap(); + + match node { + TreeCacheNode::Leaf(_) => bail!("expected branch"), + TreeCacheNode::Branch(_, map) => { + assert_eq!(map.len(), 1); + assert!(map.contains_key("b")); + } + } + + Ok(()) + } + + #[test] + fn length() -> Result<(), Error> { + let mut cache = TreeCache::new(); + + cache.set(vec!["a", "b"], "c")?; + + assert_eq!(cache.len(), 1); + + cache.set(vec!["a", "b"], "d")?; + + assert_eq!(cache.len(), 1); + + cache.set(vec!["e", "f"], "g")?; + + assert_eq!(cache.len(), 2); + + cache.set(vec!["e", "h"], "i")?; + + assert_eq!(cache.len(), 3); + + cache.set(vec!["e"], "i")?; + + assert_eq!(cache.len(), 2); + + cache.pop_node(&["a"])?; + + assert_eq!(cache.len(), 1); + + Ok(()) + } + + #[test] + fn clear() -> Result<(), Error> { + let mut cache = TreeCache::new(); + + cache.set(vec!["a", "b"], "c")?; + + assert_eq!(cache.len(), 1); + + cache.clear(); + + assert_eq!(cache.len(), 0); + + assert_eq!(cache.get(&["a", "b"])?, None); + + Ok(()) + } + + #[test] + fn pop() -> Result<(), Error> { + let mut cache = TreeCache::new(); + + cache.set(vec!["a", "b"], "c")?; + assert_eq!(cache.pop(&["a", "b"])?, Some("c")); + assert_eq!(cache.pop(&["a", "b"])?, None); + + Ok(()) + } + + #[test] + fn values() -> Result<(), Error> { + let mut cache = TreeCache::new(); + + cache.set(vec!["a", "b"], "c")?; + + let expected = ["c"].iter().collect(); + assert_eq!(cache.values().collect::>(), expected); + + cache.set(vec!["d", "e"], "f")?; + + let expected = ["c", "f"].iter().collect(); + assert_eq!(cache.values().collect::>(), expected); + + Ok(()) + } + + #[test] + fn items() -> Result<(), Error> { + let mut cache = TreeCache::new(); + + cache.set(vec!["a", "b"], "c")?; + cache.set(vec!["d", "e"], "f")?; + + let expected = [(vec![&"a", &"b"], &"c"), (vec![&"d", &"e"], &"f")] + .into_iter() + .collect(); + assert_eq!(cache.items().collect::>(), expected); + + Ok(()) + } +} -- cgit 1.5.1