1use crate::lean_imt::{LeanIMT, LeanIMTError, MerkleProof};
6
/// Hash function plugged into a [`HashedLeanIMT`].
///
/// `hash` is an associated function (it takes no `self`), so `HashedLeanIMT`
/// hands `H::hash` directly to the underlying [`LeanIMT`] as a function value
/// rather than calling it through a stored hasher instance.
pub trait LeanIMTHasher<const N: usize> {
    /// Hashes `input` into an `N`-byte digest.
    ///
    /// NOTE(review): the byte layout of `input` (for example, how child nodes
    /// are concatenated) is decided by `LeanIMT`; confirm there before relying
    /// on a particular encoding.
    #[must_use]
    fn hash(input: &[u8]) -> [u8; N];
}
11
12#[derive(Debug, Default, Clone, PartialEq, Eq)]
14pub struct HashedLeanIMT<const N: usize, H> {
15 tree: LeanIMT<N>,
17 hasher: H,
19}
20
21impl<const N: usize, H> HashedLeanIMT<N, H>
22where
23 H: LeanIMTHasher<N>,
24{
25 pub fn new(leaves: &[[u8; N]], hasher: H) -> Result<Self, LeanIMTError> {
27 let imt = LeanIMT::new(leaves, H::hash)?;
28
29 Ok(Self { tree: imt, hasher })
30 }
31
32 pub fn new_from_tree(tree: LeanIMT<N>, hasher: H) -> Self {
34 Self { tree, hasher }
35 }
36
37 pub fn insert(&mut self, leaf: &[u8; N]) {
39 self.tree.insert(leaf, H::hash)
40 }
41
42 pub fn insert_many(&mut self, leaves: &[[u8; N]]) -> Result<(), LeanIMTError> {
48 self.tree.insert_many(leaves, H::hash)
49 }
50
51 pub fn update(&mut self, index: usize, new_leaf: &[u8; N]) -> Result<(), LeanIMTError> {
53 self.tree.update(index, new_leaf, H::hash)
54 }
55
56 pub fn generate_proof(&self, index: usize) -> Result<MerkleProof<N>, LeanIMTError> {
58 self.tree.generate_proof(index)
59 }
60
61 pub fn verify_proof(proof: &MerkleProof<N>) -> bool {
63 LeanIMT::verify_proof(proof, H::hash)
64 }
65
66 pub fn leaves(&self) -> &[[u8; N]] {
68 self.tree.leaves()
69 }
70
71 pub fn size(&self) -> usize {
73 self.tree.size()
74 }
75
76 pub fn root(&self) -> Option<[u8; N]> {
78 self.tree.root()
79 }
80
81 pub fn depth(&self) -> usize {
83 self.tree.depth()
84 }
85
86 pub fn get_leaf(&self, index: usize) -> Result<[u8; N], LeanIMTError> {
88 self.tree.get_leaf(index)
89 }
90
91 pub fn get_node(&self, level: usize, index: usize) -> Result<[u8; N], LeanIMTError> {
93 self.tree.get_node(level, index)
94 }
95
96 pub fn index_of(&self, leaf: &[u8]) -> Option<usize> {
98 self.tree.index_of(leaf)
99 }
100
101 pub fn contains(&self, leaf: &[u8]) -> bool {
103 self.tree.contains(leaf)
104 }
105
106 pub fn tree(&self) -> &LeanIMT<N> {
108 &self.tree
109 }
110
111 pub fn hasher(&self) -> &H {
113 &self.hasher
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Test-only, non-cryptographic hasher. The first 8 bytes hold a
    /// `DefaultHasher` digest of the input; the remaining 24 bytes stay zero.
    struct SampleHasher;

    impl LeanIMTHasher<32> for SampleHasher {
        fn hash(input: &[u8]) -> [u8; 32] {
            let mut hasher = DefaultHasher::new();
            input.hash(&mut hasher);
            let mut result = [0u8; 32];
            result[..8].copy_from_slice(&hasher.finish().to_le_bytes());
            result
        }
    }

    /// The concrete tree type exercised by every test below.
    type TestTree = HashedLeanIMT<32, SampleHasher>;

    #[test]
    fn test_new_empty_tree() {
        let tree = TestTree::new(&[], SampleHasher).unwrap();

        assert_eq!(tree.size(), 0);
        assert_eq!(tree.root(), None);
    }

    #[test]
    fn test_insert_leaves() {
        let mut tree = TestTree::new(&[], SampleHasher).unwrap();
        let leaf1 = [1u8; 32];

        tree.insert(&leaf1);

        assert_eq!(tree.size(), 1);
        assert!(tree.contains(&leaf1));

        let leaf2 = [2u8; 32];
        let leaf3 = [3u8; 32];

        tree.insert_many(&[leaf2, leaf3]).unwrap();

        assert_eq!(tree.size(), 3);
        assert!(tree.contains(&leaf2));
        assert!(tree.contains(&leaf3));
    }

    #[test]
    fn test_update_leaf() {
        let initial_leaves = [[0u8; 32], [1u8; 32]];
        let mut tree = TestTree::new(&initial_leaves, SampleHasher).unwrap();
        let updated_leaf = [42u8; 32];

        tree.update(0, &updated_leaf).unwrap();

        assert!(!tree.contains(&[0u8; 32]));
        assert!(tree.contains(&updated_leaf));

        // The untouched sibling must survive the update.
        assert!(tree.contains(&[1u8; 32]));
    }

    #[test]
    fn test_merkle_proof() {
        let leaves = [[0u8; 32], [1u8; 32], [2u8; 32], [3u8; 32]];
        let tree = TestTree::new(&leaves, SampleHasher).unwrap();

        let proof = tree.generate_proof(1).unwrap();
        assert!(TestTree::verify_proof(&proof));
        assert_eq!(proof.index, 1);
        assert_eq!(proof.leaf, [1u8; 32]);

        // Every leaf, not just one, must yield a verifiable proof.
        for (index, leaf) in leaves.iter().enumerate() {
            let proof = tree.generate_proof(index).unwrap();
            assert!(TestTree::verify_proof(&proof), "proof for leaf {index} must verify");
            assert_eq!(proof.leaf, *leaf);
        }
    }

    #[test]
    fn test_tampered_proof_is_rejected() {
        let leaves = [[0u8; 32], [1u8; 32], [2u8; 32], [3u8; 32]];
        let tree = TestTree::new(&leaves, SampleHasher).unwrap();
        let mut proof = tree.generate_proof(1).unwrap();

        // A proof whose leaf no longer matches the committed root must fail.
        proof.leaf = [9u8; 32];

        assert!(!TestTree::verify_proof(&proof));
    }

    #[test]
    fn test_index_of_and_get_leaf() {
        let leaves = [[0u8; 32], [1u8; 32], [2u8; 32]];
        let tree = TestTree::new(&leaves, SampleHasher).unwrap();

        assert_eq!(tree.index_of(&[1u8; 32]), Some(1));
        assert_eq!(tree.index_of(&[42u8; 32]), None);

        assert_eq!(tree.get_leaf(1), Ok([1u8; 32]));
        assert_eq!(tree.get_leaf(5), Err(LeanIMTError::IndexOutOfBounds));
    }
}