1use crate::lean_imt::{LeanIMT, LeanIMTError, MerkleProof};
6
/// Hash function abstraction used by [`HashedLeanIMT`].
///
/// `N` is the byte width of the digest. The function is an associated
/// function (no `self` receiver), so implementors are selected purely by
/// type and the wrapper passes `H::hash` to the underlying tree operations.
pub trait LeanIMTHasher<const N: usize> {
    /// Hashes `input` into a fixed-size `N`-byte digest.
    fn hash(input: &[u8]) -> [u8; N];
}
11
12#[derive(Debug, Default, Clone, PartialEq, Eq)]
14pub struct HashedLeanIMT<const N: usize, H> {
15 tree: LeanIMT<N>,
17 hasher: H,
19}
20
21impl<const N: usize, H> HashedLeanIMT<N, H>
22where
23 H: LeanIMTHasher<N>,
24{
25 pub fn new(leaves: &[[u8; N]], hasher: H) -> Result<Self, LeanIMTError> {
27 let imt = LeanIMT::new(leaves, H::hash)?;
28
29 Ok(Self { tree: imt, hasher })
30 }
31
32 pub fn new_from_tree(tree: LeanIMT<N>, hasher: H) -> Self {
34 Self { tree, hasher }
35 }
36
37 pub fn insert(&mut self, leaf: &[u8; N]) {
39 self.tree.insert(leaf, H::hash)
40 }
41
42 pub fn insert_many(&mut self, leaves: &[[u8; N]]) -> Result<(), LeanIMTError> {
44 self.tree.insert_many(leaves, H::hash)
45 }
46
47 pub fn update(&mut self, index: usize, new_leaf: &[u8; N]) -> Result<(), LeanIMTError> {
49 self.tree.update(index, new_leaf, H::hash)
50 }
51
52 pub fn generate_proof(&self, index: usize) -> Result<MerkleProof<N>, LeanIMTError> {
54 self.tree.generate_proof(index)
55 }
56
57 pub fn verify_proof(proof: &MerkleProof<N>) -> bool {
59 LeanIMT::verify_proof(proof, H::hash)
60 }
61
62 pub fn leaves(&self) -> &[[u8; N]] {
64 self.tree.leaves()
65 }
66
67 pub fn size(&self) -> usize {
69 self.tree.size()
70 }
71
72 pub fn root(&self) -> Option<[u8; N]> {
74 self.tree.root()
75 }
76
77 pub fn depth(&self) -> usize {
79 self.tree.depth()
80 }
81
82 pub fn get_leaf(&self, index: usize) -> Result<[u8; N], LeanIMTError> {
84 self.tree.get_leaf(index)
85 }
86
87 pub fn get_node(&self, level: usize, index: usize) -> Result<[u8; N], LeanIMTError> {
89 self.tree.get_node(level, index)
90 }
91
92 pub fn index_of(&self, leaf: &[u8]) -> Option<usize> {
94 self.tree.index_of(leaf)
95 }
96
97 pub fn contains(&self, leaf: &[u8]) -> bool {
99 self.tree.contains(leaf)
100 }
101
102 pub fn tree(&self) -> &LeanIMT<N> {
104 &self.tree
105 }
106
107 pub fn hasher(&self) -> &H {
109 &self.hasher
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Test-only hasher: the 64-bit `DefaultHasher` output is placed in the
    /// first 8 bytes of a 32-byte digest; the remaining bytes stay zero.
    /// Not cryptographic; it only needs to be deterministic for these tests.
    struct SampleHasher;

    impl LeanIMTHasher<32> for SampleHasher {
        fn hash(input: &[u8]) -> [u8; 32] {
            let mut state = DefaultHasher::new();
            input.hash(&mut state);
            let digest = state.finish().to_le_bytes();

            let mut out = [0u8; 32];
            out[..digest.len()].copy_from_slice(&digest);
            out
        }
    }

    /// The concrete tree type exercised by every test below.
    type Tree = HashedLeanIMT<32, SampleHasher>;

    /// A 32-byte leaf with every byte set to `fill`.
    fn leaf(fill: u8) -> [u8; 32] {
        [fill; 32]
    }

    /// Builds a tree from `leaves`, panicking if construction fails.
    fn tree_with(leaves: &[[u8; 32]]) -> Tree {
        Tree::new(leaves, SampleHasher).unwrap()
    }

    #[test]
    fn test_new_empty_tree() {
        let empty = tree_with(&[]);

        assert_eq!(empty.size(), 0);
        assert_eq!(empty.root(), None);
    }

    #[test]
    fn test_insert_leaves() {
        let mut tree = tree_with(&[]);

        // Single insertion.
        let first = leaf(1);
        tree.insert(&first);
        assert_eq!(tree.size(), 1);
        assert!(tree.contains(&first));

        // Batch insertion.
        let (second, third) = (leaf(2), leaf(3));
        tree.insert_many(&[second, third]).unwrap();
        assert_eq!(tree.size(), 3);
        assert!(tree.contains(&second));
        assert!(tree.contains(&third));
    }

    #[test]
    fn test_update_leaf() {
        let mut tree = tree_with(&[leaf(0), leaf(1)]);
        let replacement = leaf(42);

        tree.update(0, &replacement).unwrap();

        // The old value is gone and the new one is present...
        assert!(!tree.contains(&leaf(0)));
        assert!(tree.contains(&replacement));
        // ...while the untouched neighbour is still there.
        assert!(tree.contains(&leaf(1)));
    }

    #[test]
    fn test_merkle_proof() {
        let tree = tree_with(&[leaf(0), leaf(1), leaf(2), leaf(3)]);

        let proof = tree.generate_proof(1).unwrap();

        assert!(Tree::verify_proof(&proof));
        assert_eq!(proof.index, 1);
        assert_eq!(proof.leaf, leaf(1));
    }

    #[test]
    fn test_index_of_and_get_leaf() {
        let tree = tree_with(&[leaf(0), leaf(1), leaf(2)]);

        assert_eq!(tree.index_of(&leaf(1)), Some(1));
        assert_eq!(tree.index_of(&leaf(42)), None);

        assert_eq!(tree.get_leaf(1), Ok(leaf(1)));
        assert_eq!(tree.get_leaf(5), Err(LeanIMTError::IndexOutOfBounds));
    }
}