summary refs log tree commit diff
path: root/rust/src/tree_cache
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-12-24 12:36:54 +0000
committerErik Johnston <erik@matrix.org>2022-12-24 12:36:54 +0000
commitb9cdf3d85e0ef28cf89a1a307a984879ba604511 (patch)
treeaff7a2311328ce2fd18158b6ca302f251e20668a /rust/src/tree_cache
parentbindings (diff)
downloadsynapse-b9cdf3d85e0ef28cf89a1a307a984879ba604511.tar.xz
String cache
Diffstat (limited to 'rust/src/tree_cache')
-rw-r--r--rust/src/tree_cache/binding.rs64
-rw-r--r--rust/src/tree_cache/mod.rs20
2 files changed, 77 insertions, 7 deletions
diff --git a/rust/src/tree_cache/binding.rs b/rust/src/tree_cache/binding.rs
index 70207f8781..e01601daf5 100644
--- a/rust/src/tree_cache/binding.rs
+++ b/rust/src/tree_cache/binding.rs
@@ -2,7 +2,9 @@ use std::hash::Hash;
 
 use anyhow::Error;
 use pyo3::{
-    pyclass, pymethods, types::PyModule, IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject,
+    pyclass, pymethods,
+    types::{PyIterator, PyModule},
+    IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject,
 };
 
 use super::TreeCache;
@@ -10,6 +12,7 @@ use super::TreeCache;
 pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     let child_module = PyModule::new(py, "tree_cache")?;
     child_module.add_class::<PythonTreeCache>()?;
+    child_module.add_class::<StringTreeCache>()?;
 
     m.add_submodule(child_module)?;
 
@@ -126,3 +129,62 @@ impl PythonTreeCache {
         todo!()
     }
 }
+
+#[pyclass]
+struct StringTreeCache(TreeCache<String, String>);
+
+#[pymethods]
+impl StringTreeCache {
+    #[new]
+    fn new() -> Self {
+        StringTreeCache(Default::default())
+    }
+
+    pub fn set(&mut self, key: &PyAny, value: String) -> Result<(), Error> {
+        let key = key
+            .iter()?
+            .map(|o| o.expect("iter failed").extract().expect("not a string"));
+
+        self.0.set(key, value)?;
+
+        Ok(())
+    }
+
+    // pub fn get_node(&self, key: &PyAny) -> Result<Option<&TreeCacheNode<K, PyObject>>, Error> {
+    //     todo!()
+    // }
+
+    pub fn get(&self, key: &PyAny) -> Result<Option<&String>, Error> {
+        let key = key.iter()?.map(|o| {
+            o.expect("iter failed")
+                .extract::<String>()
+                .expect("not a string")
+        });
+
+        Ok(self.0.get(key)?)
+    }
+
+    // pub fn pop_node(&mut self, key: &PyAny) -> Result<Option<TreeCacheNode<K, PyObject>>, Error> {
+    //     todo!()
+    // }
+
+    pub fn pop(&mut self, key: Vec<String>) -> Result<Option<String>, Error> {
+        Ok(self.0.pop(&key)?)
+    }
+
+    pub fn clear(&mut self) {
+        self.0.clear()
+    }
+
+    pub fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    pub fn values(&self) -> Vec<&String> {
+        self.0.values().collect()
+    }
+
+    pub fn items(&self) -> Vec<(Vec<&HashablePyObject>, &PyObject)> {
+        todo!()
+    }
+}
diff --git a/rust/src/tree_cache/mod.rs b/rust/src/tree_cache/mod.rs
index 0a4905b881..719d0b2cf9 100644
--- a/rust/src/tree_cache/mod.rs
+++ b/rust/src/tree_cache/mod.rs
@@ -1,4 +1,4 @@
-use std::{collections::HashMap, hash::Hash};
+use std::{borrow::Borrow, collections::HashMap, hash::Hash};
 
 use anyhow::{bail, Error};
 
@@ -141,17 +141,21 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCache<K, V> {
         Ok(())
     }
 
-    pub fn get_node(
+    pub fn get_node<Q>(
         &self,
-        key: impl IntoIterator<Item = &'a K>,
-    ) -> Result<Option<&TreeCacheNode<K, V>>, Error> {
+        key: impl IntoIterator<Item = Q>,
+    ) -> Result<Option<&TreeCacheNode<K, V>>, Error>
+    where
+        Q: Borrow<K>,
+        Q: Hash + Eq + 'a,
+    {
         let mut node = &self.root;
 
         for k in key {
             match node {
                 TreeCacheNode::Leaf(_) => bail!("Given key is too long"),
                 TreeCacheNode::Branch(_, map) => {
-                    node = if let Some(node) = map.get(k) {
+                    node = if let Some(node) = map.get(k.borrow()) {
                         node
                     } else {
                         return Ok(None);
@@ -163,7 +167,11 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCache<K, V> {
         Ok(Some(node))
     }
 
-    pub fn get(&self, key: impl IntoIterator<Item = &'a K>) -> Result<Option<&V>, Error> {
+    pub fn get<Q>(&self, key: impl IntoIterator<Item = Q>) -> Result<Option<&V>, Error>
+    where
+        Q: Borrow<K>,
+        Q: Hash + Eq + 'a,
+    {
         if let Some(node) = self.get_node(key)? {
             match node {
                 TreeCacheNode::Leaf(value) => Ok(Some(value)),