From c03785e121e4003433e21c43b057ecbb10751ee6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Sat, 24 Dec 2022 13:44:15 +0000 Subject: Implement {get,pop}_node --- rust/src/tree_cache/binding.rs | 71 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 64 insertions(+), 7 deletions(-) (limited to 'rust/src/tree_cache/binding.rs') diff --git a/rust/src/tree_cache/binding.rs b/rust/src/tree_cache/binding.rs index e01601daf5..6e274f5268 100644 --- a/rust/src/tree_cache/binding.rs +++ b/rust/src/tree_cache/binding.rs @@ -3,7 +3,7 @@ use std::hash::Hash; use anyhow::Error; use pyo3::{ pyclass, pymethods, - types::{PyIterator, PyModule}, + types::{PyModule, PyTuple}, IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject, }; @@ -25,6 +25,7 @@ pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { Ok(()) } +#[derive(Clone)] struct HashablePyObject { obj: PyObject, hash: isize, @@ -41,12 +42,24 @@ impl HashablePyObject { } } +impl IntoPy for HashablePyObject { + fn into_py(self, _: Python<'_>) -> PyObject { + self.obj.clone() + } +} + impl IntoPy for &HashablePyObject { fn into_py(self, _: Python<'_>) -> PyObject { self.obj.clone() } } +impl ToPyObject for HashablePyObject { + fn to_object(&self, _py: Python<'_>) -> PyObject { + self.obj.clone() + } +} + impl Hash for HashablePyObject { fn hash(&self, state: &mut H) { self.hash.hash(state); @@ -87,9 +100,31 @@ impl PythonTreeCache { Ok(()) } - // pub fn get_node(&self, key: &PyAny) -> Result>, Error> { - // todo!() - // } + pub fn get_node<'a>( + &'a self, + py: Python<'a>, + key: &'a PyAny, + ) -> Result>, Error> { + let v: Vec = key + .iter()? + .map(|obj| HashablePyObject::new(obj?)) + .collect::>()?; + + let Some(node) = self.0.get_node(v.clone())? else { + return Ok(None) + }; + + let items = node + .items() + .map(|(k, value)| { + let vec = v.iter().chain(k.iter().map(|a| *a)).collect::>(); + let nk = PyTuple::new(py, vec); + (nk, value) + }) + .collect::>(); + + Ok(Some(items)) + } pub fn get(&self, key: &PyAny) -> Result, Error> { let v: Vec = key @@ -100,9 +135,31 @@ impl PythonTreeCache { Ok(self.0.get(&v)?) } - // pub fn pop_node(&mut self, key: &PyAny) -> Result>, Error> { - // todo!() - // } + pub fn pop_node<'a>( + &'a mut self, + py: Python<'a>, + key: &'a PyAny, + ) -> Result>, Error> { + let v: Vec = key + .iter()? + .map(|obj| HashablePyObject::new(obj?)) + .collect::>()?; + + let Some(node) = self.0.pop_node(v.clone())? else { + return Ok(None) + }; + + let items = node + .into_items() + .map(|(k, value)| { + let vec = v.iter().chain(k.iter()).collect::>(); + let nk = PyTuple::new(py, vec); + (nk, value) + }) + .collect::>(); + + Ok(Some(items)) + } pub fn pop(&mut self, key: &PyAny) -> Result, Error> { let v: Vec = key -- cgit 1.5.1