diff --git a/rust/src/tree_cache/binding.rs b/rust/src/tree_cache/binding.rs
index 70207f8781..e01601daf5 100644
--- a/rust/src/tree_cache/binding.rs
+++ b/rust/src/tree_cache/binding.rs
@@ -2,7 +2,9 @@ use std::hash::Hash;
use anyhow::Error;
use pyo3::{
- pyclass, pymethods, types::PyModule, IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject,
+ pyclass, pymethods,
+ types::{PyIterator, PyModule},
+ IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject,
};
use super::TreeCache;
@@ -10,6 +12,7 @@ use super::TreeCache;
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let child_module = PyModule::new(py, "tree_cache")?;
child_module.add_class::<PythonTreeCache>()?;
+ child_module.add_class::<StringTreeCache>()?;
m.add_submodule(child_module)?;
@@ -126,3 +129,62 @@ impl PythonTreeCache {
todo!()
}
}
+
+#[pyclass]
+struct StringTreeCache(TreeCache<String, String>);
+
+#[pymethods]
+impl StringTreeCache {
+ #[new]
+ fn new() -> Self {
+ StringTreeCache(Default::default())
+ }
+
+ pub fn set(&mut self, key: &PyAny, value: String) -> Result<(), Error> {
+ let key = key
+ .iter()?
+ .map(|o| o.expect("iter failed").extract().expect("not a string"));
+
+ self.0.set(key, value)?;
+
+ Ok(())
+ }
+
+ // pub fn get_node(&self, key: &PyAny) -> Result<Option<&TreeCacheNode<K, PyObject>>, Error> {
+ // todo!()
+ // }
+
+ pub fn get(&self, key: &PyAny) -> Result<Option<&String>, Error> {
+ let key = key.iter()?.map(|o| {
+ o.expect("iter failed")
+ .extract::<String>()
+ .expect("not a string")
+ });
+
+ Ok(self.0.get(key)?)
+ }
+
+ // pub fn pop_node(&mut self, key: &PyAny) -> Result<Option<TreeCacheNode<K, PyObject>>, Error> {
+ // todo!()
+ // }
+
+ pub fn pop(&mut self, key: Vec<String>) -> Result<Option<String>, Error> {
+ Ok(self.0.pop(&key)?)
+ }
+
+ pub fn clear(&mut self) {
+ self.0.clear()
+ }
+
+ pub fn len(&self) -> usize {
+ self.0.len()
+ }
+
+ pub fn values(&self) -> Vec<&String> {
+ self.0.values().collect()
+ }
+
+ pub fn items(&self) -> Vec<(Vec<&HashablePyObject>, &PyObject)> {
+ todo!()
+ }
+}
diff --git a/rust/src/tree_cache/mod.rs b/rust/src/tree_cache/mod.rs
index 0a4905b881..719d0b2cf9 100644
--- a/rust/src/tree_cache/mod.rs
+++ b/rust/src/tree_cache/mod.rs
@@ -1,4 +1,4 @@
-use std::{collections::HashMap, hash::Hash};
+use std::{borrow::Borrow, collections::HashMap, hash::Hash};
use anyhow::{bail, Error};
@@ -141,17 +141,21 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCache<K, V> {
Ok(())
}
- pub fn get_node(
+ pub fn get_node<Q>(
&self,
- key: impl IntoIterator<Item = &'a K>,
- ) -> Result<Option<&TreeCacheNode<K, V>>, Error> {
+ key: impl IntoIterator<Item = Q>,
+ ) -> Result<Option<&TreeCacheNode<K, V>>, Error>
+ where
+ Q: Borrow<K>,
+ Q: Hash + Eq + 'a,
+ {
let mut node = &self.root;
for k in key {
match node {
TreeCacheNode::Leaf(_) => bail!("Given key is too long"),
TreeCacheNode::Branch(_, map) => {
- node = if let Some(node) = map.get(k) {
+ node = if let Some(node) = map.get(k.borrow()) {
node
} else {
return Ok(None);
@@ -163,7 +167,11 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCache<K, V> {
Ok(Some(node))
}
- pub fn get(&self, key: impl IntoIterator<Item = &'a K>) -> Result<Option<&V>, Error> {
+ pub fn get<Q>(&self, key: impl IntoIterator<Item = Q>) -> Result<Option<&V>, Error>
+ where
+ Q: Borrow<K>,
+ Q: Hash + Eq + 'a,
+ {
if let Some(node) = self.get_node(key)? {
match node {
TreeCacheNode::Leaf(value) => Ok(Some(value)),
|