summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-12-05 14:03:28 +0000
committerErik Johnston <erik@matrix.org>2022-12-05 14:03:28 +0000
commit18ac015ecdd66fde7275efaaf07995b35ad4c41b (patch)
treec7a65dfa9770af82498095102d2ce1223de5b449
parentAdd tree cache (diff)
downloadsynapse-18ac015ecdd66fde7275efaaf07995b35ad4c41b.tar.xz
bindings
-rw-r--r--rust/src/lib.rs1
-rw-r--r--rust/src/tree_cache/binding.rs128
-rw-r--r--rust/src/tree_cache/mod.rs (renamed from rust/src/tree_cache.rs)18
3 files changed, 146 insertions, 1 deletions
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 00f72dc59f..6db2b1eae2 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -27,6 +27,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?;
 
     push::register_module(py, m)?;
+    tree_cache::binding::register_module(py, m)?;
 
     Ok(())
 }
diff --git a/rust/src/tree_cache/binding.rs b/rust/src/tree_cache/binding.rs
new file mode 100644
index 0000000000..70207f8781
--- /dev/null
+++ b/rust/src/tree_cache/binding.rs
@@ -0,0 +1,128 @@
+use std::hash::Hash;
+
+use anyhow::Error;
+use pyo3::{
+    pyclass, pymethods, types::PyModule, IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject,
+};
+
+use super::TreeCache;
+
+pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
+    let child_module = PyModule::new(py, "tree_cache")?;
+    child_module.add_class::<PythonTreeCache>()?;
+
+    m.add_submodule(child_module)?;
+
+    // We need to manually add the module to sys.modules to make `from
+    // synapse.synapse_rust import push` work.
+    py.import("sys")?
+        .getattr("modules")?
+        .set_item("synapse.synapse_rust.tree_cache", child_module)?;
+
+    Ok(())
+}
+
+struct HashablePyObject {
+    obj: PyObject,
+    hash: isize,
+}
+
+impl HashablePyObject {
+    pub fn new(obj: &PyAny) -> Result<Self, Error> {
+        let hash = obj.hash()?;
+
+        Ok(HashablePyObject {
+            obj: obj.to_object(obj.py()),
+            hash,
+        })
+    }
+}
+
+impl IntoPy<PyObject> for &HashablePyObject {
+    fn into_py(self, _: Python<'_>) -> PyObject {
+        self.obj.clone()
+    }
+}
+
+impl Hash for HashablePyObject {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.hash.hash(state);
+    }
+}
+
+impl PartialEq for HashablePyObject {
+    fn eq(&self, other: &Self) -> bool {
+        let equal = Python::with_gil(|py| {
+            let result = self.obj.as_ref(py).eq(other.obj.as_ref(py));
+            result.unwrap_or(false)
+        });
+
+        equal
+    }
+}
+
+impl Eq for HashablePyObject {}
+
+#[pyclass]
+struct PythonTreeCache(TreeCache<HashablePyObject, PyObject>);
+
+#[pymethods]
+impl PythonTreeCache {
+    #[new]
+    fn new() -> Self {
+        PythonTreeCache(Default::default())
+    }
+
+    pub fn set(&mut self, key: &PyAny, value: PyObject) -> Result<(), Error> {
+        let v: Vec<HashablePyObject> = key
+            .iter()?
+            .map(|obj| HashablePyObject::new(obj?))
+            .collect::<Result<_, _>>()?;
+
+        self.0.set(v, value)?;
+
+        Ok(())
+    }
+
+    // pub fn get_node(&self, key: &PyAny) -> Result<Option<&TreeCacheNode<K, PyObject>>, Error> {
+    //     todo!()
+    // }
+
+    pub fn get(&self, key: &PyAny) -> Result<Option<&PyObject>, Error> {
+        let v: Vec<HashablePyObject> = key
+            .iter()?
+            .map(|obj| HashablePyObject::new(obj?))
+            .collect::<Result<_, _>>()?;
+
+        Ok(self.0.get(&v)?)
+    }
+
+    // pub fn pop_node(&mut self, key: &PyAny) -> Result<Option<TreeCacheNode<K, PyObject>>, Error> {
+    //     todo!()
+    // }
+
+    pub fn pop(&mut self, key: &PyAny) -> Result<Option<PyObject>, Error> {
+        let v: Vec<HashablePyObject> = key
+            .iter()?
+            .map(|obj| HashablePyObject::new(obj?))
+            .collect::<Result<_, _>>()?;
+
+        Ok(self.0.pop(&v)?)
+    }
+
+    pub fn clear(&mut self) {
+        self.0.clear()
+    }
+
+    pub fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    pub fn values(&self) -> Vec<&PyObject> {
+        self.0.values().collect()
+    }
+
+    pub fn items(&self) -> Vec<(Vec<&HashablePyObject>, &PyObject)> {
+        todo!()
+    }
+}
diff --git a/rust/src/tree_cache.rs b/rust/src/tree_cache/mod.rs
index 6796229d64..0a4905b881 100644
--- a/rust/src/tree_cache.rs
+++ b/rust/src/tree_cache/mod.rs
@@ -2,6 +2,8 @@ use std::{collections::HashMap, hash::Hash};
 
 use anyhow::{bail, Error};
 
+pub mod binding;
+
 pub enum TreeCacheNode<K, V> {
     Leaf(V),
     Branch(usize, HashMap<K, TreeCacheNode<K, V>>),
@@ -114,17 +116,25 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCacheNode<K, V> {
     }
 }
 
+impl<K, V> Default for TreeCacheNode<K, V> {
+    fn default() -> Self {
+        TreeCacheNode::new_branch()
+    }
+}
+
 pub struct TreeCache<K, V> {
     root: TreeCacheNode<K, V>,
 }
 
-impl<'a, K: Eq + Hash + 'a, V> TreeCache<K, V> {
+impl<K, V> TreeCache<K, V> {
     pub fn new() -> Self {
         TreeCache {
             root: TreeCacheNode::new_branch(),
         }
     }
+}
 
+impl<'a, K: Eq + Hash + 'a, V> TreeCache<K, V> {
     pub fn set(&mut self, key: impl IntoIterator<Item = K>, value: V) -> Result<(), Error> {
         self.root.set(key.into_iter(), value)?;
 
@@ -224,6 +234,12 @@ impl<'a, K: Eq + Hash + 'a, V> TreeCache<K, V> {
     }
 }
 
+impl<K, V> Default for TreeCache<K, V> {
+    fn default() -> Self {
+        TreeCache::new()
+    }
+}
+
 #[cfg(test)]
 mod test {
     use std::collections::BTreeSet;