diff options
Diffstat (limited to 'rust/src/tree_cache')
-rw-r--r-- | rust/src/tree_cache/binding.rs | 71 | ||||
-rw-r--r-- | rust/src/tree_cache/mod.rs | 53 |
2 files changed, 106 insertions, 18 deletions
diff --git a/rust/src/tree_cache/binding.rs b/rust/src/tree_cache/binding.rs index e01601daf5..6e274f5268 100644 --- a/rust/src/tree_cache/binding.rs +++ b/rust/src/tree_cache/binding.rs @@ -3,7 +3,7 @@ use std::hash::Hash; use anyhow::Error; use pyo3::{ pyclass, pymethods, - types::{PyIterator, PyModule}, + types::{PyModule, PyTuple}, IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject, }; @@ -25,6 +25,7 @@ pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { Ok(()) } +#[derive(Clone)] struct HashablePyObject { obj: PyObject, hash: isize, @@ -41,12 +42,24 @@ impl HashablePyObject { } } +impl IntoPy<PyObject> for HashablePyObject { + fn into_py(self, _: Python<'_>) -> PyObject { + self.obj.clone() + } +} + impl IntoPy<PyObject> for &HashablePyObject { fn into_py(self, _: Python<'_>) -> PyObject { self.obj.clone() } } +impl ToPyObject for HashablePyObject { + fn to_object(&self, _py: Python<'_>) -> PyObject { + self.obj.clone() + } +} + impl Hash for HashablePyObject { fn hash<H: std::hash::Hasher>(&self, state: &mut H) { self.hash.hash(state); @@ -87,9 +100,31 @@ impl PythonTreeCache { Ok(()) } - // pub fn get_node(&self, key: &PyAny) -> Result<Option<&TreeCacheNode<K, PyObject>>, Error> { - // todo!() - // } + pub fn get_node<'a>( + &'a self, + py: Python<'a>, + key: &'a PyAny, + ) -> Result<Option<Vec<(&'a PyTuple, &'a PyObject)>>, Error> { + let v: Vec<HashablePyObject> = key + .iter()? + .map(|obj| HashablePyObject::new(obj?)) + .collect::<Result<_, _>>()?; + + let Some(node) = self.0.get_node(v.clone())? else { + return Ok(None) + }; + + let items = node + .items() + .map(|(k, value)| { + let vec = v.iter().chain(k.iter().map(|a| *a)).collect::<Vec<_>>(); + let nk = PyTuple::new(py, vec); + (nk, value) + }) + .collect::<Vec<_>>(); + + Ok(Some(items)) + } pub fn get(&self, key: &PyAny) -> Result<Option<&PyObject>, Error> { let v: Vec<HashablePyObject> = key @@ -100,9 +135,31 @@ impl PythonTreeCache { Ok(self.0.get(&v)?) } - // pub fn pop_node(&mut self, key: &PyAny) -> Result<Option<TreeCacheNode<K, PyObject>>, Error> { - // todo!() - // } + pub fn pop_node<'a>( + &'a mut self, + py: Python<'a>, + key: &'a PyAny, + ) -> Result<Option<Vec<(&'a PyTuple, PyObject)>>, Error> { + let v: Vec<HashablePyObject> = key + .iter()? + .map(|obj| HashablePyObject::new(obj?)) + .collect::<Result<_, _>>()?; + + let Some(node) = self.0.pop_node(v.clone())? else { + return Ok(None) + }; + + let items = node + .into_items() + .map(|(k, value)| { + let vec = v.iter().chain(k.iter()).collect::<Vec<_>>(); + let nk = PyTuple::new(py, vec); + (nk, value) + }) + .collect::<Vec<_>>(); + + Ok(Some(items)) + } pub fn pop(&mut self, key: &PyAny) -> Result<Option<PyObject>, Error> { let v: Vec<HashablePyObject> = key diff --git a/rust/src/tree_cache/mod.rs b/rust/src/tree_cache/mod.rs index 719d0b2cf9..cb45e1f70d 100644 --- a/rust/src/tree_cache/mod.rs +++ b/rust/src/tree_cache/mod.rs @@ -54,16 +54,20 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCacheNode<K, V> { } } - pub fn pop( + pub fn pop<Q>( &mut self, - current_key: &K, - mut next_keys: impl Iterator<Item = &'a K>, - ) -> Result<Option<TreeCacheNode<K, V>>, Error> { + current_key: Q, + mut next_keys: impl Iterator<Item = Q>, + ) -> Result<Option<TreeCacheNode<K, V>>, Error> + where + Q: Borrow<K>, + Q: Hash + Eq + 'a, + { if let Some(next_key) = next_keys.next() { match self { TreeCacheNode::Leaf(_) => bail!("Given key is too long"), TreeCacheNode::Branch(size, map) => { - let node = if let Some(node) = map.get_mut(current_key) { + let node = if let Some(node) = map.get_mut(current_key.borrow()) { node } else { return Ok(None); @@ -82,7 +86,7 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCacheNode<K, V> { match self { TreeCacheNode::Leaf(_) => bail!("Given key is too long"), TreeCacheNode::Branch(size, map) => { - if let Some(node) = map.remove(current_key) { + if let Some(node) = map.remove(current_key.borrow()) { *size -= node.len(); Ok(Some(node)) @@ -94,8 +98,8 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCacheNode<K, V> { } } - pub fn items(&self) -> impl Iterator<Item = (Vec<&K>, &V)> { - let mut stack = vec![(vec![], self)]; + pub fn items(&'a self) -> impl Iterator<Item = (Vec<&K>, &V)> { + let mut stack = vec![(Vec::new(), self)]; std::iter::from_fn(move || { while let Some((prefix, node)) = stack.pop() { @@ -116,6 +120,29 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCacheNode<K, V> { } } +impl<'a, K: Clone + Eq + Hash + 'a, V> TreeCacheNode<K, V> { + pub fn into_items(self) -> impl Iterator<Item = (Vec<K>, V)> { + let mut stack = vec![(Vec::new(), self)]; + + std::iter::from_fn(move || { + while let Some((prefix, node)) = stack.pop() { + match node { + TreeCacheNode::Leaf(value) => return Some((prefix, value)), + TreeCacheNode::Branch(_, map) => { + stack.extend(map.into_iter().map(|(k, v)| { + let mut prefix = prefix.clone(); + prefix.push(k); + (prefix, v) + })); + } + } + } + + None + }) + } +} + impl<K, V> Default for TreeCacheNode<K, V> { fn default() -> Self { TreeCacheNode::new_branch() @@ -182,10 +209,14 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCache<K, V> { } } - pub fn pop_node( + pub fn pop_node<Q>( &mut self, - key: impl IntoIterator<Item = &'a K>, - ) -> Result<Option<TreeCacheNode<K, V>>, Error> { + key: impl IntoIterator<Item = Q>, + ) -> Result<Option<TreeCacheNode<K, V>>, Error> + where + Q: Borrow<K>, + Q: Hash + Eq + 'a, + { let mut key_iter = key.into_iter(); let k = if let Some(k) = key_iter.next() { |