1#![allow(clippy::manual_div_ceil)]
9
10use thiserror::Error;
11
/// A Lean Incremental Merkle Tree whose nodes are fixed-size `N`-byte values.
///
/// The tree is stored level by level: `nodes[0]` holds the leaves, each
/// following entry holds the parents of the previous one, and the last entry
/// holds the single root once the tree is non-empty. A parent is the hash of
/// `left || right`; a left node with no right sibling is promoted to the next
/// level unchanged.
#[derive(Debug, Clone, PartialEq, Eq)]
// NOTE(review): serde only implements (De)Serialize for arrays up to length 32,
// so these explicit `[u8; N]` bounds presumably limit serde support to
// N <= 32 — confirm.
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(bound(
        serialize = "[u8; N]: serde::Serialize",
        deserialize = "[u8; N]: serde::Deserialize<'de>"
    ))
)]
pub struct LeanIMT<const N: usize> {
    /// Tree levels, leaves first (`nodes[0]`) and root level last.
    nodes: Vec<Vec<[u8; N]>>,
}
26
27impl<const N: usize> Default for LeanIMT<N> {
28 fn default() -> Self {
29 Self {
30 nodes: vec![Vec::new()],
31 }
32 }
33}
34
35impl<const N: usize> LeanIMT<N> {
36 pub fn new(leaves: &[[u8; N]], hash: impl Fn(&[u8]) -> [u8; N]) -> Result<Self, LeanIMTError> {
38 let mut imt = Self::default();
39
40 match leaves.len() {
41 0 => {},
42 1 => imt.insert(&leaves[0], hash),
43 _ => imt.insert_many(leaves, hash)?,
44 }
45
46 Ok(imt)
47 }
48
49 pub fn insert(&mut self, leaf: &[u8; N], hash: impl Fn(&[u8]) -> [u8; N]) {
51 let mut depth = self.depth();
52
53 if self.size() + 1 > (1 << depth) {
55 self.nodes.push(Vec::new());
56 depth += 1;
57 }
58
59 let mut node = *leaf;
60 let mut index = self.size();
61
62 for level in &mut self.nodes {
63 if level.len() <= index {
65 level.push(node);
66 } else {
67 level[index] = node;
68 }
69
70 if index % 2 == 1 {
72 let mut hash_input = Vec::with_capacity(N * 2);
73
74 hash_input.extend_from_slice(&level[index - 1]);
76 hash_input.extend_from_slice(&node);
77
78 node = hash(&hash_input);
79 }
80
81 index >>= 1;
83 }
84
85 self.nodes[depth] = vec![node];
86 }
87
88 pub fn insert_many(
94 &mut self,
95 leaves: &[[u8; N]],
96 hash: impl Fn(&[u8]) -> [u8; N],
97 ) -> Result<(), LeanIMTError> {
98 if leaves.is_empty() {
99 return Err(LeanIMTError::EmptyBatchInsert);
100 }
101
102 let start_index = self.size();
103 self.nodes[0].extend_from_slice(leaves);
104
105 let required_depth = self.size().next_power_of_two().trailing_zeros() as usize;
107 while self.depth() < required_depth {
108 self.nodes.push(Vec::new());
109 }
110
111 let mut index = start_index / 2;
113 for level in 0..self.depth() {
114 let level_len = self.nodes[level].len();
115 let start_parent_idx = index;
116 let num_parents = (level_len + 1) / 2;
117
118 for parent_idx in start_parent_idx..num_parents {
120 let left_idx = parent_idx * 2;
121 let left = self.nodes[level][left_idx];
122
123 let parent = if left_idx + 1 < level_len {
124 let right = self.nodes[level][left_idx + 1];
126
127 let mut hash_input = Vec::with_capacity(2 * N);
128 hash_input.extend_from_slice(&left);
129 hash_input.extend_from_slice(&right);
130 hash(&hash_input)
131 } else {
132 left
134 };
135
136 let next_level = &mut self.nodes[level + 1];
138 if parent_idx < next_level.len() {
139 next_level[parent_idx] = parent;
140 } else {
141 next_level.push(parent);
142 }
143 }
144
145 index /= 2;
147 }
148
149 Ok(())
150 }
151
152 pub fn update(
154 &mut self,
155 mut index: usize,
156 new_leaf: &[u8; N],
157 hash: impl Fn(&[u8]) -> [u8; N],
158 ) -> Result<(), LeanIMTError> {
159 if index >= self.size() {
160 return Err(LeanIMTError::IndexOutOfBounds);
161 }
162
163 let mut node = *new_leaf;
164
165 let depth = self.depth();
166 for level in 0..depth {
167 self.nodes[level][index] = node;
168 if index & 1 != 0 {
169 let sibling = self.nodes[level][index - 1];
170 let mut hash_input = Vec::with_capacity(N * 2);
171 hash_input.extend_from_slice(&sibling);
172 hash_input.extend_from_slice(&node);
173 node = hash(&hash_input);
174 } else if let Some(sibling) = self.nodes[level].get(index + 1).copied() {
175 let mut hash_input = Vec::with_capacity(N * 2);
176 hash_input.extend_from_slice(&node);
177 hash_input.extend_from_slice(&sibling);
178 node = hash(&hash_input);
179 }
180 index >>= 1;
181 }
182
183 self.nodes[depth][0] = node;
184 Ok(())
185 }
186
187 pub fn generate_proof(&self, mut index: usize) -> Result<MerkleProof<N>, LeanIMTError> {
189 if index >= self.size() {
190 return Err(LeanIMTError::IndexOutOfBounds);
191 }
192
193 let leaf = self.leaves()[index];
194 let mut siblings = Vec::new();
195 let mut path = Vec::new();
196
197 for level in 0..self.depth() {
198 let is_right = index & 1 != 0;
199 let sibling_idx = if is_right { index - 1 } else { index + 1 };
200
201 if let Some(sibling) = self.nodes[level].get(sibling_idx).copied() {
202 path.push(is_right);
203 siblings.push(sibling);
204 }
205
206 index >>= 1;
207 }
208
209 let final_index = path
210 .iter()
211 .rev()
212 .fold(0, |acc, &is_right| (acc << 1) | is_right as usize);
213
214 Ok(MerkleProof {
215 root: self.nodes[self.depth()][0],
216 leaf,
217 index: final_index,
218 siblings,
219 })
220 }
221
222 pub fn verify_proof(proof: &MerkleProof<N>, hash: impl Fn(&[u8]) -> [u8; N]) -> bool {
224 let mut node = proof.leaf;
225
226 for (i, sibling) in proof.siblings.iter().enumerate() {
227 let mut hash_input = Vec::with_capacity(N * 2);
228
229 if (proof.index >> i) & 1 != 0 {
230 hash_input.extend_from_slice(sibling);
232 hash_input.extend_from_slice(&node);
233 } else {
234 hash_input.extend_from_slice(&node);
236 hash_input.extend_from_slice(sibling);
237 }
238
239 node = hash(&hash_input);
240 }
241
242 proof.root == node
243 }
244
245 pub fn leaves(&self) -> &[[u8; N]] {
247 if self.nodes.is_empty() {
248 &[]
249 } else {
250 &self.nodes[0]
251 }
252 }
253
254 pub fn size(&self) -> usize {
256 self.leaves().len()
257 }
258
259 pub fn root(&self) -> Option<[u8; N]> {
261 self.nodes.last()?.first().copied()
262 }
263
264 pub fn depth(&self) -> usize {
266 self.nodes.len().saturating_sub(1)
267 }
268
269 pub fn get_leaf(&self, index: usize) -> Result<[u8; N], LeanIMTError> {
271 self.leaves()
272 .get(index)
273 .copied()
274 .ok_or(LeanIMTError::IndexOutOfBounds)
275 }
276
277 pub fn nodes(&self) -> &[Vec<[u8; N]>] {
279 &self.nodes
280 }
281
282 pub fn get_node(&self, level: usize, index: usize) -> Result<[u8; N], LeanIMTError> {
284 let level_vec = self
285 .nodes
286 .get(level)
287 .ok_or(LeanIMTError::LevelOutOfBounds)?;
288
289 level_vec
290 .get(index)
291 .copied()
292 .ok_or(LeanIMTError::IndexOutOfBounds)
293 }
294
295 pub fn index_of(&self, leaf: &[u8]) -> Option<usize> {
297 self.leaves().iter().position(|x| x == leaf)
298 }
299
300 pub fn contains(&self, leaf: &[u8]) -> bool {
302 self.index_of(leaf).is_some()
303 }
304}
305
/// Inclusion proof for a single leaf of a [`LeanIMT`].
///
/// Produced by `LeanIMT::generate_proof` and checked by
/// `LeanIMT::verify_proof`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(bound(
        serialize = "[u8; N]: serde::Serialize",
        deserialize = "[u8; N]: serde::Deserialize<'de>"
    ))
)]
pub struct MerkleProof<const N: usize> {
    /// Root of the tree the proof was generated from.
    pub root: [u8; N],
    /// The leaf whose inclusion is being proven.
    pub leaf: [u8; N],
    /// Compressed path bits: bit `i` is 1 when `siblings[i]` is the left
    /// operand of the `i`-th hash (the running node is a right child).
    /// Levels where the node had no sibling are skipped, so this is not
    /// necessarily the leaf's position in the tree.
    pub index: usize,
    /// Sibling nodes from the leaf level upward, omitting levels where the
    /// node was promoted without a sibling.
    pub siblings: Vec<[u8; N]>,
}
326
/// Errors returned by [`LeanIMT`] operations.
#[derive(Error, Debug, PartialEq, Eq)]
pub enum LeanIMTError {
    /// A leaf index is `>= size()`, or a node index is past the end of its level.
    #[error("Index out of bounds")]
    IndexOutOfBounds,
    /// Presumably: a leaf of the wrong byte length.
    /// NOTE(review): nothing in the visible code constructs this variant — confirm it is still needed.
    #[error("Invalid leaf size")]
    InvalidLeafSize,
    /// A level index passed to `get_node` is greater than the tree depth.
    #[error("Level out of bounds")]
    LevelOutOfBounds,
    /// `insert_many` was called with an empty slice.
    #[error("Empty batch insert")]
    EmptyBatchInsert,
}
338
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Test hash: `DefaultHasher` output truncated to 4 bytes. Not
    /// cryptographic; it only needs to be deterministic within a test run.
    fn hash(input: &[u8]) -> [u8; 4] {
        let mut hasher = DefaultHasher::new();

        for byte in input {
            byte.hash(&mut hasher);
        }
        let hash = hasher.finish();

        let mut result = [0u8; 4];
        result.copy_from_slice(&hash.to_le_bytes()[..4]);
        result
    }

    fn u32_to_leaf(n: u32) -> [u8; 4] {
        n.to_le_bytes()
    }

    fn generate_leaves(size: u32) -> Vec<[u8; 4]> {
        (0..size).map(u32_to_leaf).collect()
    }

    #[test]
    fn test_new_tree_empty() {
        let leaves: Vec<[u8; 4]> = vec![];
        let tree = LeanIMT::new(&leaves, hash).unwrap();

        assert_eq!(tree.size(), 0);
        assert_eq!(tree.root(), None);
        assert_eq!(tree.depth(), 0);

        let leaves: &[[u8; 4]] = tree.leaves();
        let empty_leaves: &[[u8; 4]] = &[];
        assert_eq!(leaves, empty_leaves);
    }

    #[test]
    fn test_insert_single_leaf() {
        let mut tree = LeanIMT::new(&[], hash).unwrap();
        let leaf = u32_to_leaf(1);
        tree.insert(&leaf, hash);

        assert_eq!(tree.root(), Some(leaf));
        assert_eq!(tree.size(), 1);
    }

    #[test]
    fn test_insert_multiple_leaves() {
        let leaves = generate_leaves(5);
        let tree_from_batch = LeanIMT::new(&leaves, hash).unwrap();

        let mut tree_iter = LeanIMT::new(&[], hash).unwrap();
        for leaf in leaves.iter() {
            tree_iter.insert(leaf, hash);
        }

        assert_eq!(tree_from_batch, tree_iter);
    }

    /// Batch construction must equal one-by-one insertion for every size,
    /// including sizes just above and below powers of two.
    #[test]
    fn test_batch_matches_incremental_for_many_sizes() {
        for size in 0..=17 {
            let leaves = generate_leaves(size);
            let batch = LeanIMT::new(&leaves, hash).unwrap();

            let mut incremental = LeanIMT::new(&[], hash).unwrap();
            for leaf in &leaves {
                incremental.insert(leaf, hash);
            }

            assert_eq!(batch, incremental, "size {size}");
        }
    }

    /// `insert_many` on a non-empty tree must give the same tree as one big batch.
    #[test]
    fn test_insert_many_onto_non_empty_tree() {
        let leaves = generate_leaves(13);
        let expected = LeanIMT::new(&leaves, hash).unwrap();

        for split in 0..leaves.len() {
            let mut tree = LeanIMT::new(&leaves[..split], hash).unwrap();
            tree.insert_many(&leaves[split..], hash).unwrap();
            assert_eq!(tree, expected, "split at {split}");
        }
    }

    #[test]
    fn test_index_of_and_contains() {
        let leaves = generate_leaves(5);
        let tree = LeanIMT::new(&leaves, hash).unwrap();

        assert_eq!(tree.index_of(&u32_to_leaf(2)), Some(2));
        assert!(tree.contains(&u32_to_leaf(2)));

        assert_eq!(tree.index_of(&u32_to_leaf(999)), None);
        assert!(!tree.contains(&u32_to_leaf(999)));
    }

    #[test]
    fn test_update_leaf() {
        let leaves = generate_leaves(5);
        let mut tree = LeanIMT::new(&leaves, hash).unwrap();

        let new_leaf = u32_to_leaf(42);
        tree.update(0, &new_leaf, hash).unwrap();
        assert_eq!(tree.get_leaf(0).unwrap(), new_leaf);

        let proof = tree.generate_proof(0).unwrap();
        assert!(LeanIMT::verify_proof(&proof, hash));
    }

    /// Updating any leaf must give exactly the tree rebuilt from the modified leaves.
    #[test]
    fn test_update_matches_rebuild() {
        let leaves = generate_leaves(7);

        for index in 0..leaves.len() {
            let mut modified = leaves.clone();
            modified[index] = u32_to_leaf(1000);

            let mut tree = LeanIMT::new(&leaves, hash).unwrap();
            tree.update(index, &modified[index], hash).unwrap();

            assert_eq!(tree, LeanIMT::new(&modified, hash).unwrap(), "index {index}");
        }
    }

    #[test]
    fn test_generate_and_verify_proof() {
        let leaves = generate_leaves(5);
        let tree = LeanIMT::new(&leaves, hash).unwrap();

        for i in 0..leaves.len() {
            let proof = tree.generate_proof(i).unwrap();
            assert_eq!(proof.leaf, leaves[i]);
            assert_eq!(proof.root, tree.root().unwrap());
            assert!(LeanIMT::verify_proof(&proof, hash));
        }
    }

    #[test]
    fn test_proofs_verify_for_many_sizes() {
        for size in 1..=17 {
            let leaves = generate_leaves(size);
            let tree = LeanIMT::new(&leaves, hash).unwrap();

            for (i, leaf) in leaves.iter().enumerate() {
                let proof = tree.generate_proof(i).unwrap();
                assert_eq!(&proof.leaf, leaf);
                assert!(LeanIMT::verify_proof(&proof, hash), "size {size}, index {i}");
            }
        }
    }

    #[test]
    fn test_tampered_proof_fails() {
        let tree = LeanIMT::new(&generate_leaves(6), hash).unwrap();
        let mut proof = tree.generate_proof(3).unwrap();
        proof.leaf = u32_to_leaf(12345);
        assert!(!LeanIMT::verify_proof(&proof, hash));
    }

    #[test]
    fn test_generate_proof_invalid_index() {
        let leaves = generate_leaves(5);
        let tree = LeanIMT::new(&leaves, hash).unwrap();

        let result = tree.generate_proof(999);
        assert!(matches!(result, Err(LeanIMTError::IndexOutOfBounds)));
    }

    #[test]
    fn test_update_invalid_index() {
        let leaves = generate_leaves(5);
        let mut tree = LeanIMT::new(&leaves, hash).unwrap();

        let result = tree.update(100, &u32_to_leaf(10), hash);
        assert!(matches!(result, Err(LeanIMTError::IndexOutOfBounds)));
    }

    #[test]
    fn test_empty_batch_insert_errors() {
        let mut tree = LeanIMT::new(&generate_leaves(3), hash).unwrap();
        assert_eq!(tree.insert_many(&[], hash), Err(LeanIMTError::EmptyBatchInsert));
        assert_eq!(tree.size(), 3);
    }

    #[test]
    fn test_get_node_bounds() {
        let tree = LeanIMT::new(&generate_leaves(5), hash).unwrap();

        assert_eq!(tree.get_node(0, 4).unwrap(), u32_to_leaf(4));
        assert_eq!(tree.get_node(tree.depth(), 0).unwrap(), tree.root().unwrap());
        assert_eq!(
            tree.get_node(tree.depth() + 1, 0),
            Err(LeanIMTError::LevelOutOfBounds)
        );
        assert_eq!(tree.get_node(0, 5), Err(LeanIMTError::IndexOutOfBounds));
    }
}
461}