summary refs log tree commit diff
path: root/rust/src
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-12-24 13:44:15 +0000
committerErik Johnston <erik@matrix.org>2022-12-24 13:44:17 +0000
commitc03785e121e4003433e21c43b057ecbb10751ee6 (patch)
tree1e4d3e6c0b7e186b75a7706f338db144c4ab9ab8 /rust/src
parentString cache (diff)
downloadsynapse-c03785e121e4003433e21c43b057ecbb10751ee6.tar.xz
Implement {get,pop}_node
Diffstat (limited to 'rust/src')
-rw-r--r--rust/src/tree_cache/binding.rs71
-rw-r--r--rust/src/tree_cache/mod.rs53
2 files changed, 106 insertions, 18 deletions
diff --git a/rust/src/tree_cache/binding.rs b/rust/src/tree_cache/binding.rs
index e01601daf5..6e274f5268 100644
--- a/rust/src/tree_cache/binding.rs
+++ b/rust/src/tree_cache/binding.rs
@@ -3,7 +3,7 @@ use std::hash::Hash;
 use anyhow::Error;
 use pyo3::{
     pyclass, pymethods,
-    types::{PyIterator, PyModule},
+    types::{PyModule, PyTuple},
     IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject,
 };
 
@@ -25,6 +25,7 @@ pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     Ok(())
 }
 
+#[derive(Clone)]
 struct HashablePyObject {
     obj: PyObject,
     hash: isize,
@@ -41,12 +42,24 @@ impl HashablePyObject {
     }
 }
 
+impl IntoPy<PyObject> for HashablePyObject {
+    fn into_py(self, _: Python<'_>) -> PyObject {
+        self.obj.clone()
+    }
+}
+
 impl IntoPy<PyObject> for &HashablePyObject {
     fn into_py(self, _: Python<'_>) -> PyObject {
         self.obj.clone()
     }
 }
 
+impl ToPyObject for HashablePyObject {
+    fn to_object(&self, _py: Python<'_>) -> PyObject {
+        self.obj.clone()
+    }
+}
+
 impl Hash for HashablePyObject {
     fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
         self.hash.hash(state);
@@ -87,9 +100,31 @@ impl PythonTreeCache {
         Ok(())
     }
 
-    // pub fn get_node(&self, key: &PyAny) -> Result<Option<&TreeCacheNode<K, PyObject>>, Error> {
-    //     todo!()
-    // }
+    pub fn get_node<'a>(
+        &'a self,
+        py: Python<'a>,
+        key: &'a PyAny,
+    ) -> Result<Option<Vec<(&'a PyTuple, &'a PyObject)>>, Error> {
+        let v: Vec<HashablePyObject> = key
+            .iter()?
+            .map(|obj| HashablePyObject::new(obj?))
+            .collect::<Result<_, _>>()?;
+
+        let Some(node) = self.0.get_node(v.clone())? else {
+            return Ok(None)
+        };
+
+        let items = node
+            .items()
+            .map(|(k, value)| {
+                let vec = v.iter().chain(k.iter().map(|a| *a)).collect::<Vec<_>>();
+                let nk = PyTuple::new(py, vec);
+                (nk, value)
+            })
+            .collect::<Vec<_>>();
+
+        Ok(Some(items))
+    }
 
     pub fn get(&self, key: &PyAny) -> Result<Option<&PyObject>, Error> {
         let v: Vec<HashablePyObject> = key
@@ -100,9 +135,31 @@ impl PythonTreeCache {
         Ok(self.0.get(&v)?)
     }
 
-    // pub fn pop_node(&mut self, key: &PyAny) -> Result<Option<TreeCacheNode<K, PyObject>>, Error> {
-    //     todo!()
-    // }
+    pub fn pop_node<'a>(
+        &'a mut self,
+        py: Python<'a>,
+        key: &'a PyAny,
+    ) -> Result<Option<Vec<(&'a PyTuple, PyObject)>>, Error> {
+        let v: Vec<HashablePyObject> = key
+            .iter()?
+            .map(|obj| HashablePyObject::new(obj?))
+            .collect::<Result<_, _>>()?;
+
+        let Some(node) = self.0.pop_node(v.clone())? else {
+            return Ok(None)
+        };
+
+        let items = node
+            .into_items()
+            .map(|(k, value)| {
+                let vec = v.iter().chain(k.iter()).collect::<Vec<_>>();
+                let nk = PyTuple::new(py, vec);
+                (nk, value)
+            })
+            .collect::<Vec<_>>();
+
+        Ok(Some(items))
+    }
 
     pub fn pop(&mut self, key: &PyAny) -> Result<Option<PyObject>, Error> {
         let v: Vec<HashablePyObject> = key
diff --git a/rust/src/tree_cache/mod.rs b/rust/src/tree_cache/mod.rs
index 719d0b2cf9..cb45e1f70d 100644
--- a/rust/src/tree_cache/mod.rs
+++ b/rust/src/tree_cache/mod.rs
@@ -54,16 +54,20 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCacheNode<K, V> {
         }
     }
 
-    pub fn pop(
+    pub fn pop<Q>(
         &mut self,
-        current_key: &K,
-        mut next_keys: impl Iterator<Item = &'a K>,
-    ) -> Result<Option<TreeCacheNode<K, V>>, Error> {
+        current_key: Q,
+        mut next_keys: impl Iterator<Item = Q>,
+    ) -> Result<Option<TreeCacheNode<K, V>>, Error>
+    where
+        Q: Borrow<K>,
+        Q: Hash + Eq + 'a,
+    {
         if let Some(next_key) = next_keys.next() {
             match self {
                 TreeCacheNode::Leaf(_) => bail!("Given key is too long"),
                 TreeCacheNode::Branch(size, map) => {
-                    let node = if let Some(node) = map.get_mut(current_key) {
+                    let node = if let Some(node) = map.get_mut(current_key.borrow()) {
                         node
                     } else {
                         return Ok(None);
@@ -82,7 +86,7 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCacheNode<K, V> {
             match self {
                 TreeCacheNode::Leaf(_) => bail!("Given key is too long"),
                 TreeCacheNode::Branch(size, map) => {
-                    if let Some(node) = map.remove(current_key) {
+                    if let Some(node) = map.remove(current_key.borrow()) {
                         *size -= node.len();
 
                         Ok(Some(node))
@@ -94,8 +98,8 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCacheNode<K, V> {
         }
     }
 
-    pub fn items(&self) -> impl Iterator<Item = (Vec<&K>, &V)> {
-        let mut stack = vec![(vec![], self)];
+    pub fn items(&'a self) -> impl Iterator<Item = (Vec<&K>, &V)> {
+        let mut stack = vec![(Vec::new(), self)];
 
         std::iter::from_fn(move || {
             while let Some((prefix, node)) = stack.pop() {
@@ -116,6 +120,29 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCacheNode<K, V> {
     }
 }
 
+impl<'a, K: Clone + Eq + Hash + 'a, V> TreeCacheNode<K, V> {
+    pub fn into_items(self) -> impl Iterator<Item = (Vec<K>, V)> {
+        let mut stack = vec![(Vec::new(), self)];
+
+        std::iter::from_fn(move || {
+            while let Some((prefix, node)) = stack.pop() {
+                match node {
+                    TreeCacheNode::Leaf(value) => return Some((prefix, value)),
+                    TreeCacheNode::Branch(_, map) => {
+                        stack.extend(map.into_iter().map(|(k, v)| {
+                            let mut prefix = prefix.clone();
+                            prefix.push(k);
+                            (prefix, v)
+                        }));
+                    }
+                }
+            }
+
+            None
+        })
+    }
+}
+
 impl<K, V> Default for TreeCacheNode<K, V> {
     fn default() -> Self {
         TreeCacheNode::new_branch()
@@ -182,10 +209,14 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCache<K, V> {
         }
     }
 
-    pub fn pop_node(
+    pub fn pop_node<Q>(
         &mut self,
-        key: impl IntoIterator<Item = &'a K>,
-    ) -> Result<Option<TreeCacheNode<K, V>>, Error> {
+        key: impl IntoIterator<Item = Q>,
+    ) -> Result<Option<TreeCacheNode<K, V>>, Error>
+    where
+        Q: Borrow<K>,
+        Q: Hash + Eq + 'a,
+    {
         let mut key_iter = key.into_iter();
 
         let k = if let Some(k) = key_iter.next() {