lean_imt/
lean_imt.rs

1//! # LeanIMT
2//!
3//! Lean Incremental Merkle Tree implementation.
4//!
5//! Specifications can be found here:
6//!  - <https://github.com/privacy-scaling-explorations/zk-kit/blob/main/papers/leanimt/paper/leanimt-paper.pdf>
7
8#![allow(clippy::manual_div_ceil)]
9
10use thiserror::Error;
11
/// Lean Incremental Merkle Tree over fixed-size, `N`-byte nodes.
///
/// The tree does not store a hash function: every operation that needs one
/// (`new`, `insert`, `insert_many`, `update`, `verify_proof`) takes it as an
/// argument, so callers are expected to pass the same function each time.
#[derive(Debug, Clone, PartialEq, Eq)]
// With the `serde` feature enabled the (de)serialization impls are derived,
// with their trait bounds stated explicitly on `[u8; N]`.
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(bound(
        serialize = "[u8; N]: serde::Serialize",
        deserialize = "[u8; N]: serde::Deserialize<'de>"
    ))
)]
pub struct LeanIMT<const N: usize> {
    /// Nodes storage, one `Vec` per level: level 0 holds the leaves and the
    /// first node of the last level is the root.
    nodes: Vec<Vec<[u8; N]>>,
}
26
27impl<const N: usize> Default for LeanIMT<N> {
28    fn default() -> Self {
29        Self {
30            nodes: vec![Vec::new()],
31        }
32    }
33}
34
35impl<const N: usize> LeanIMT<N> {
36    /// Creates a new tree with optional initial leaves.
37    pub fn new(leaves: &[[u8; N]], hash: impl Fn(&[u8]) -> [u8; N]) -> Result<Self, LeanIMTError> {
38        let mut imt = Self::default();
39
40        match leaves.len() {
41            0 => {},
42            1 => imt.insert(&leaves[0], hash),
43            _ => imt.insert_many(leaves, hash)?,
44        }
45
46        Ok(imt)
47    }
48
49    /// Inserts a single leaf.
50    pub fn insert(&mut self, leaf: &[u8; N], hash: impl Fn(&[u8]) -> [u8; N]) {
51        let mut depth = self.depth();
52
53        // Expand capacity if exceeded.
54        if self.size() + 1 > (1 << depth) {
55            self.nodes.push(Vec::new());
56            depth += 1;
57        }
58
59        let mut node = *leaf;
60        let mut index = self.size();
61
62        for level in &mut self.nodes {
63            // If the level is smaller than the expected index, we push a node
64            if level.len() <= index {
65                level.push(node);
66            } else {
67                level[index] = node;
68            }
69
70            // If we are at an odd index, we hash the leaves.
71            if index % 2 == 1 {
72                let mut hash_input = Vec::with_capacity(N * 2);
73
74                // Sibling goes first.
75                hash_input.extend_from_slice(&level[index - 1]);
76                hash_input.extend_from_slice(&node);
77
78                node = hash(&hash_input);
79            }
80
81            // Divide the expected index by 2.
82            index >>= 1;
83        }
84
85        self.nodes[depth] = vec![node];
86    }
87
88    /// Inserts multiple leaves.
89    ///
90    /// # Errors
91    ///
92    /// Will return [`LeanIMTError::EmptyBatchInsert`] if `leaves` is an empty array
93    pub fn insert_many(
94        &mut self,
95        leaves: &[[u8; N]],
96        hash: impl Fn(&[u8]) -> [u8; N],
97    ) -> Result<(), LeanIMTError> {
98        if leaves.is_empty() {
99            return Err(LeanIMTError::EmptyBatchInsert);
100        }
101
102        let start_index = self.size();
103        self.nodes[0].extend_from_slice(leaves);
104
105        // Ensure the tree has enough levels
106        let required_depth = self.size().next_power_of_two().trailing_zeros() as usize;
107        while self.depth() < required_depth {
108            self.nodes.push(Vec::new());
109        }
110
111        // Start from level 0 and update parent nodes
112        let mut index = start_index / 2;
113        for level in 0..self.depth() {
114            let level_len = self.nodes[level].len();
115            let start_parent_idx = index;
116            let num_parents = (level_len + 1) / 2;
117
118            // Process each parent node starting from the affected index
119            for parent_idx in start_parent_idx..num_parents {
120                let left_idx = parent_idx * 2;
121                let left = self.nodes[level][left_idx];
122
123                let parent = if left_idx + 1 < level_len {
124                    // Node has both children, hash them
125                    let right = self.nodes[level][left_idx + 1];
126
127                    let mut hash_input = Vec::with_capacity(2 * N);
128                    hash_input.extend_from_slice(&left);
129                    hash_input.extend_from_slice(&right);
130                    hash(&hash_input)
131                } else {
132                    // Node has only left child, propagate it
133                    left
134                };
135
136                // Update or add parent node
137                let next_level = &mut self.nodes[level + 1];
138                if parent_idx < next_level.len() {
139                    next_level[parent_idx] = parent;
140                } else {
141                    next_level.push(parent);
142                }
143            }
144
145            // Update index for the next level
146            index /= 2;
147        }
148
149        Ok(())
150    }
151
152    /// Updates a leaf at the given index.
153    pub fn update(
154        &mut self,
155        mut index: usize,
156        new_leaf: &[u8; N],
157        hash: impl Fn(&[u8]) -> [u8; N],
158    ) -> Result<(), LeanIMTError> {
159        if index >= self.size() {
160            return Err(LeanIMTError::IndexOutOfBounds);
161        }
162
163        let mut node = *new_leaf;
164
165        let depth = self.depth();
166        for level in 0..depth {
167            self.nodes[level][index] = node;
168            if index & 1 != 0 {
169                let sibling = self.nodes[level][index - 1];
170                let mut hash_input = Vec::with_capacity(N * 2);
171                hash_input.extend_from_slice(&sibling);
172                hash_input.extend_from_slice(&node);
173                node = hash(&hash_input);
174            } else if let Some(sibling) = self.nodes[level].get(index + 1).copied() {
175                let mut hash_input = Vec::with_capacity(N * 2);
176                hash_input.extend_from_slice(&node);
177                hash_input.extend_from_slice(&sibling);
178                node = hash(&hash_input);
179            }
180            index >>= 1;
181        }
182
183        self.nodes[depth][0] = node;
184        Ok(())
185    }
186
187    /// Generates a Merkle proof for a leaf at the given index.
188    pub fn generate_proof(&self, mut index: usize) -> Result<MerkleProof<N>, LeanIMTError> {
189        if index >= self.size() {
190            return Err(LeanIMTError::IndexOutOfBounds);
191        }
192
193        let leaf = self.leaves()[index];
194        let mut siblings = Vec::new();
195        let mut path = Vec::new();
196
197        for level in 0..self.depth() {
198            let is_right = index & 1 != 0;
199            let sibling_idx = if is_right { index - 1 } else { index + 1 };
200
201            if let Some(sibling) = self.nodes[level].get(sibling_idx).copied() {
202                path.push(is_right);
203                siblings.push(sibling);
204            }
205
206            index >>= 1;
207        }
208
209        let final_index = path
210            .iter()
211            .rev()
212            .fold(0, |acc, &is_right| (acc << 1) | is_right as usize);
213
214        Ok(MerkleProof {
215            root: self.nodes[self.depth()][0],
216            leaf,
217            index: final_index,
218            siblings,
219        })
220    }
221
222    /// Verifies a Merkle proof.
223    pub fn verify_proof(proof: &MerkleProof<N>, hash: impl Fn(&[u8]) -> [u8; N]) -> bool {
224        let mut node = proof.leaf;
225
226        for (i, sibling) in proof.siblings.iter().enumerate() {
227            let mut hash_input = Vec::with_capacity(N * 2);
228
229            if (proof.index >> i) & 1 != 0 {
230                // Right node
231                hash_input.extend_from_slice(sibling);
232                hash_input.extend_from_slice(&node);
233            } else {
234                // Left node
235                hash_input.extend_from_slice(&node);
236                hash_input.extend_from_slice(sibling);
237            }
238
239            node = hash(&hash_input);
240        }
241
242        proof.root == node
243    }
244
245    /// Returns the leaves.
246    pub fn leaves(&self) -> &[[u8; N]] {
247        if self.nodes.is_empty() {
248            &[]
249        } else {
250            &self.nodes[0]
251        }
252    }
253
254    /// Returns the number of leaves in the tree.
255    pub fn size(&self) -> usize {
256        self.leaves().len()
257    }
258
259    /// Returns the tree root, if it exists.
260    pub fn root(&self) -> Option<[u8; N]> {
261        self.nodes.last()?.first().copied()
262    }
263
264    /// Returns the tree depth.
265    pub fn depth(&self) -> usize {
266        self.nodes.len().saturating_sub(1)
267    }
268
269    /// Retrieves a leaf at the given index.
270    pub fn get_leaf(&self, index: usize) -> Result<[u8; N], LeanIMTError> {
271        self.leaves()
272            .get(index)
273            .copied()
274            .ok_or(LeanIMTError::IndexOutOfBounds)
275    }
276
277    /// Returns the internal nodes structure.
278    pub fn nodes(&self) -> &[Vec<[u8; N]>] {
279        &self.nodes
280    }
281
282    /// Retrieves the node at a specified level and index.
283    pub fn get_node(&self, level: usize, index: usize) -> Result<[u8; N], LeanIMTError> {
284        let level_vec = self
285            .nodes
286            .get(level)
287            .ok_or(LeanIMTError::LevelOutOfBounds)?;
288
289        level_vec
290            .get(index)
291            .copied()
292            .ok_or(LeanIMTError::IndexOutOfBounds)
293    }
294
295    /// Finds the index of a given leaf, if it exists.
296    pub fn index_of(&self, leaf: &[u8]) -> Option<usize> {
297        self.leaves().iter().position(|x| x == leaf)
298    }
299
300    /// Checks whether the tree contains the specified leaf.
301    pub fn contains(&self, leaf: &[u8]) -> bool {
302        self.index_of(leaf).is_some()
303    }
304}
305
/// Merkle proof: the data needed to recompute `root` from `leaf`.
#[derive(Debug, Clone, PartialEq, Eq)]
// With the `serde` feature enabled the (de)serialization impls are derived,
// with their trait bounds stated explicitly on `[u8; N]`.
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(bound(
        serialize = "[u8; N]: serde::Serialize",
        deserialize = "[u8; N]: serde::Deserialize<'de>"
    ))
)]
pub struct MerkleProof<const N: usize> {
    /// Tree root the proof is checked against.
    pub root: [u8; N],
    /// Leaf whose membership is being proven.
    pub leaf: [u8; N],
    /// Path packed as an integer: bit `i` is set when the node being hashed
    /// with `siblings[i]` is the right-hand one (the sibling goes first).
    pub index: usize,
    /// Sibling nodes, ordered from the leaf level upwards.
    pub siblings: Vec<[u8; N]>,
}
326
/// Errors reported by [`LeanIMT`] operations.
#[derive(Error, Debug, PartialEq, Eq)]
pub enum LeanIMTError {
    /// The requested leaf or node index does not exist.
    #[error("Index out of bounds")]
    IndexOutOfBounds,
    /// A leaf has an unexpected size.
    // NOTE(review): not returned by any code visible in this file — confirm it is still needed.
    #[error("Invalid leaf size")]
    InvalidLeafSize,
    /// The requested tree level does not exist.
    #[error("Level out of bounds")]
    LevelOutOfBounds,
    /// A batch insertion was attempted with no leaves.
    #[error("Empty batch insert")]
    EmptyBatchInsert,
}
338
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Test-only 4-byte hash: feeds every input byte to a `DefaultHasher`
    /// and keeps the four low-order bytes of the 64-bit result.
    fn hash(input: &[u8]) -> [u8; 4] {
        let mut state = DefaultHasher::new();
        input.iter().for_each(|byte| byte.hash(&mut state));

        let digest = state.finish().to_le_bytes();
        [digest[0], digest[1], digest[2], digest[3]]
    }

    /// Encodes a `u32` as a little-endian 4-byte leaf.
    fn u32_to_leaf(n: u32) -> [u8; 4] {
        n.to_le_bytes()
    }

    /// Builds the leaves `0, 1, ..., size - 1`.
    fn generate_leaves(size: u32) -> Vec<[u8; 4]> {
        let mut leaves = Vec::with_capacity(size as usize);
        for n in 0..size {
            leaves.push(u32_to_leaf(n));
        }
        leaves
    }

    #[test]
    fn test_new_tree_empty() {
        let no_leaves: Vec<[u8; 4]> = Vec::new();
        let tree = LeanIMT::new(&no_leaves, hash).unwrap();

        assert_eq!(tree.size(), 0);
        assert_eq!(tree.root(), None);
        assert_eq!(tree.depth(), 0);
        assert!(tree.leaves().is_empty());
    }

    #[test]
    fn test_insert_single_leaf() {
        let only_leaf = u32_to_leaf(1);

        let mut tree = LeanIMT::new(&[], hash).unwrap();
        tree.insert(&only_leaf, hash);

        assert_eq!(tree.size(), 1);
        // With a single leaf, the leaf itself is the root.
        assert_eq!(tree.root(), Some(only_leaf));
    }

    #[test]
    fn test_insert_multiple_leaves() {
        let leaves = generate_leaves(5);

        // Same leaves, inserted one at a time into an initially empty tree...
        let mut incremental = LeanIMT::new(&[], hash).unwrap();
        leaves.iter().for_each(|leaf| incremental.insert(leaf, hash));

        // ...must produce the same tree as building it in one batch.
        let batched = LeanIMT::new(&leaves, hash).unwrap();
        assert_eq!(batched, incremental);
    }

    #[test]
    fn test_index_of_and_contains() {
        let tree = LeanIMT::new(&generate_leaves(5), hash).unwrap();

        let present = u32_to_leaf(2);
        assert_eq!(tree.index_of(&present), Some(2));
        assert!(tree.contains(&present));

        let absent = u32_to_leaf(999);
        assert_eq!(tree.index_of(&absent), None);
        assert!(!tree.contains(&absent));
    }

    #[test]
    fn test_update_leaf() {
        let mut tree = LeanIMT::new(&generate_leaves(5), hash).unwrap();

        let replacement = u32_to_leaf(42);
        tree.update(0, &replacement, hash).unwrap();
        assert_eq!(tree.get_leaf(0).unwrap(), replacement);

        // The updated tree must still yield a valid proof for that leaf.
        let proof = tree.generate_proof(0).unwrap();
        assert!(LeanIMT::verify_proof(&proof, hash));
    }

    #[test]
    fn test_generate_and_verify_proof() {
        let leaves = generate_leaves(5);
        let tree = LeanIMT::new(&leaves, hash).unwrap();
        let root = tree.root().unwrap();

        for (i, leaf) in leaves.iter().enumerate() {
            let proof = tree.generate_proof(i).unwrap();

            assert_eq!(proof.leaf, *leaf);
            assert_eq!(proof.root, root);
            assert!(LeanIMT::verify_proof(&proof, hash));
        }
    }

    #[test]
    fn test_generate_proof_invalid_index() {
        let tree = LeanIMT::new(&generate_leaves(5), hash).unwrap();

        assert_eq!(
            tree.generate_proof(999),
            Err(LeanIMTError::IndexOutOfBounds)
        );
    }

    #[test]
    fn test_update_invalid_index() {
        let mut tree = LeanIMT::new(&generate_leaves(5), hash).unwrap();

        assert_eq!(
            tree.update(100, &u32_to_leaf(10), hash),
            Err(LeanIMTError::IndexOutOfBounds)
        );
    }
}