// zk_kit_imt/imt.rs

/// An incremental Merkle tree with a fixed `depth` and `arity`.
///
/// `nodes[0]` holds the leaves and `nodes[depth]` the root. Each level only
/// stores its populated prefix; any missing position reads as that level's
/// entry in `zeroes`.
#[derive(Debug, Clone)]
pub struct IMT {
    /// Stored nodes per level, leaves first.
    nodes: Vec<Vec<IMTNode>>,
    /// Value of an empty node at each level below the root.
    zeroes: Vec<IMTNode>,
    /// Function hashing `arity` children into their parent.
    hash: IMTHashFunction,
    /// Number of hashing levels between the leaves and the root.
    depth: usize,
    /// Number of children per internal node.
    arity: usize,
}

/// An inclusion proof produced by [`IMT::create_proof`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IMTMerkleProof {
    /// Root of the tree at the time the proof was created.
    root: IMTNode,
    /// The proven leaf.
    leaf: IMTNode,
    /// Position of the leaf (then of each ancestor) inside its sibling group, bottom-up.
    path_indices: Vec<usize>,
    /// The other `arity - 1` members of each sibling group, bottom-up, in position order.
    siblings: Vec<Vec<IMTNode>>,
}

/// A tree node: a leaf, a zero value or a hash output.
pub type IMTNode = String;
/// Hashes the `arity` children of a node, in order, into the parent node.
pub type IMTHashFunction = fn(Vec<IMTNode>) -> IMTNode;

impl IMT {
    /// Creates a tree of the given `depth` and `arity`, pre-populated with `leaves`.
    ///
    /// `zero_value` fills empty leaf positions; the empty value of each higher
    /// level is the hash of `arity` copies of the level below's empty value.
    ///
    /// # Errors
    ///
    /// Returns an error if `arity` is zero or if `leaves` holds more than
    /// `arity^depth` elements.
    pub fn new(
        hash: IMTHashFunction,
        depth: usize,
        zero_value: IMTNode,
        arity: usize,
        leaves: Vec<IMTNode>,
    ) -> Result<IMT, &'static str> {
        if arity == 0 {
            return Err("The tree arity must be greater than zero");
        }
        if Self::exceeds_capacity(leaves.len(), arity, depth) {
            return Err("The tree cannot contain more than arity^depth leaves");
        }

        // zeroes[level] is the value of an empty node at `level`.
        let zeroes: Vec<IMTNode> = std::iter::successors(Some(zero_value), |zero| {
            Some(hash(vec![zero.clone(); arity]))
        })
        .take(depth)
        .collect();

        let mut imt = IMT {
            nodes: Vec::with_capacity(depth + 1),
            zeroes,
            hash,
            depth,
            arity,
        };
        imt.nodes.push(leaves);

        // Build each level from the one below; a trailing partial group of
        // children is padded with that level's zero value.
        for level in 0..depth {
            let parents: Vec<IMTNode> = (0..imt.nodes[level].len())
                .step_by(arity)
                .map(|start| imt.hash_group(level, start))
                .collect();
            imt.nodes.push(parents);
        }

        Ok(imt)
    }

    /// Returns the root, or `None` while the tree has no leaves.
    pub fn root(&self) -> Option<IMTNode> {
        self.nodes[self.depth].first().cloned()
    }

    /// Returns the number of hashing levels between the leaves and the root.
    pub fn depth(&self) -> usize {
        self.depth
    }

    /// Returns a copy of every stored level, leaves first.
    pub fn nodes(&self) -> Vec<Vec<IMTNode>> {
        self.nodes.clone()
    }

    /// Returns a copy of the empty-node value of each level below the root.
    pub fn zeroes(&self) -> Vec<IMTNode> {
        self.zeroes.clone()
    }

    /// Returns a copy of the stored leaves, in position order.
    pub fn leaves(&self) -> Vec<IMTNode> {
        self.nodes[0].clone()
    }

    /// Returns the number of children per internal node.
    pub fn arity(&self) -> usize {
        self.arity
    }

    /// Appends `leaf` at the next free position and recomputes its path to the root.
    ///
    /// # Errors
    ///
    /// Returns an error if the tree already holds `arity^depth` leaves.
    pub fn insert(&mut self, leaf: IMTNode) -> Result<(), &'static str> {
        if Self::exceeds_capacity(self.nodes[0].len() + 1, self.arity, self.depth) {
            return Err("The tree is full");
        }

        let index = self.nodes[0].len();
        self.nodes[0].push(leaf);
        self.recompute_path(index);
        Ok(())
    }

    /// Replaces the leaf at `index` with `new_leaf` and recomputes its path to the root.
    ///
    /// # Errors
    ///
    /// Returns an error if no leaf exists at `index`.
    pub fn update(&mut self, index: usize, new_leaf: IMTNode) -> Result<(), &'static str> {
        let leaf = self.nodes[0]
            .get_mut(index)
            .ok_or("The leaf does not exist in this tree")?;
        *leaf = new_leaf;
        self.recompute_path(index);
        Ok(())
    }

    /// Overwrites the leaf at `index` with the leaf-level zero value.
    ///
    /// The number of stored leaves does not change.
    ///
    /// # Errors
    ///
    /// Returns an error if no leaf exists at `index`, or if the tree has depth
    /// zero (such a tree stores no zero value to write).
    pub fn delete(&mut self, index: usize) -> Result<(), &'static str> {
        if index >= self.nodes[0].len() {
            return Err("The leaf does not exist in this tree");
        }
        let zero = self
            .zeroes
            .first()
            .cloned()
            .ok_or("A tree of depth zero has no zero value to delete with")?;
        self.update(index, zero)
    }

    /// Builds an inclusion proof for the leaf at `index`.
    ///
    /// For every level the proof records the position of the current node
    /// inside its group of `arity` siblings and the other members of that group
    /// (missing positions are filled with the level's zero value).
    ///
    /// # Errors
    ///
    /// Returns an error if no leaf exists at `index`.
    pub fn create_proof(&self, index: usize) -> Result<IMTMerkleProof, &'static str> {
        let leaf = self.nodes[0]
            .get(index)
            .cloned()
            .ok_or("The leaf does not exist in this tree")?;

        let mut siblings = Vec::with_capacity(self.depth);
        let mut path_indices = Vec::with_capacity(self.depth);
        let mut current = index;

        for level in 0..self.depth {
            let position = current % self.arity;
            let start = current - position;
            let level_siblings: Vec<IMTNode> = (start..start + self.arity)
                .filter(|&i| i != current)
                .map(|i| self.node_or_zero(level, i).clone())
                .collect();

            path_indices.push(position);
            siblings.push(level_siblings);
            current /= self.arity;
        }

        Ok(IMTMerkleProof {
            // A leaf exists, so the root level is populated.
            root: self.nodes[self.depth][0].clone(),
            leaf,
            path_indices,
            siblings,
        })
    }

    /// Checks that `proof` hashes from its leaf up to the root it records,
    /// using this tree's hash function.
    ///
    /// The result is compared with `proof.root`, not with this tree's current
    /// root. Malformed proofs (mismatched lengths, out-of-range positions)
    /// yield `false` instead of panicking.
    pub fn verify_proof(&self, proof: &IMTMerkleProof) -> bool {
        if proof.path_indices.len() != proof.siblings.len() {
            return false;
        }

        let mut node = proof.leaf.clone();
        for (group, &position) in proof.siblings.iter().zip(&proof.path_indices) {
            // `Vec::insert` panics when the position is past the end.
            if position > group.len() {
                return false;
            }
            let mut children = group.clone();
            children.insert(position, node);
            node = (self.hash)(children);
        }

        node == proof.root
    }

    /// Returns `true` when `len` leaves do not fit into `arity^depth` positions.
    ///
    /// The power uses checked arithmetic: if it overflows `usize` (e.g. arity 2
    /// with depth 64) no in-memory leaf count can reach it, so the answer is
    /// `false` rather than an overflow panic.
    fn exceeds_capacity(len: usize, arity: usize, depth: usize) -> bool {
        if depth > u32::MAX as usize {
            // Such a depth would need billions of level vectors; treat as unbounded.
            return false;
        }
        match arity.checked_pow(depth as u32) {
            Some(capacity) => len > capacity,
            None => false,
        }
    }

    /// Returns the node stored at `position` on `level` (which must be below
    /// the root), or that level's zero value when the position is unpopulated.
    fn node_or_zero(&self, level: usize, position: usize) -> &IMTNode {
        self.nodes[level]
            .get(position)
            .unwrap_or(&self.zeroes[level])
    }

    /// Hashes the `arity` consecutive nodes of `level` starting at `start`.
    fn hash_group(&self, level: usize, start: usize) -> IMTNode {
        let children: Vec<IMTNode> = (start..start + self.arity)
            .map(|i| self.node_or_zero(level, i).clone())
            .collect();
        (self.hash)(children)
    }

    /// Recomputes every ancestor of the leaf at `index`, bottom-up.
    fn recompute_path(&mut self, mut index: usize) {
        for level in 0..self.depth {
            let parent = self.hash_group(level, index - index % self.arity);
            index /= self.arity;

            let parents = &mut self.nodes[level + 1];
            if index < parents.len() {
                parents[index] = parent;
            } else {
                // Only the first child of a new group can create a parent, and
                // that parent is always the next free slot on its level.
                debug_assert_eq!(index, parents.len());
                parents.push(parent);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test hash: comma-joins the children so expected nodes can be written out literally.
    fn simple_hash_function(nodes: Vec<String>) -> String {
        nodes.join(",")
    }

    /// Builds a tree using `simple_hash_function` and `"zero"` as the zero value.
    fn build(depth: usize, arity: usize, leaves: &[&str]) -> IMT {
        let leaves = leaves.iter().map(|&leaf| String::from(leaf)).collect();
        IMT::new(simple_hash_function, depth, "zero".to_string(), arity, leaves).unwrap()
    }

    #[test]
    fn test_new_imt() {
        let hash: IMTHashFunction = simple_hash_function;
        let imt = IMT::new(hash, 3, "zero".to_string(), 2, vec![]);

        assert!(imt.is_ok());
        let imt = imt.unwrap();
        assert_eq!(imt.zeroes(), vec!["zero", "zero,zero", "zero,zero,zero,zero"]);
        assert!(imt.leaves().is_empty());
    }

    #[test]
    fn test_insertion() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = IMT::new(hash, 3, "zero".to_string(), 2, vec![]).unwrap();

        assert!(imt.insert("leaf1".to_string()).is_ok());
        assert_eq!(imt.leaves(), vec!["leaf1"]);
        assert_eq!(
            imt.root(),
            Some("leaf1,zero,zero,zero,zero,zero,zero,zero".to_string())
        );
    }

    #[test]
    fn incremental_insertion_matches_bulk_construction() {
        let leaves = ["a", "b", "c", "d", "e"];
        for arity in 2..4 {
            let mut imt = build(3, arity, &[]);
            for leaf in leaves.iter() {
                imt.insert(leaf.to_string()).unwrap();
            }
            assert_eq!(imt.nodes(), build(3, arity, &leaves).nodes());
        }
    }

    #[test]
    fn test_delete() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = IMT::new(hash, 3, "zero".to_string(), 2, vec!["leaf1".to_string()]).unwrap();

        assert!(imt.delete(0).is_ok());
        // Deletion zeroes the leaf in place; the leaf count is unchanged.
        assert_eq!(imt.leaves(), vec!["zero"]);
        assert_eq!(imt.nodes(), build(3, 2, &["zero"]).nodes());
    }

    #[test]
    fn test_update() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = IMT::new(hash, 3, "zero".to_string(), 2, vec!["leaf1".to_string()]).unwrap();

        assert!(imt.update(0, "new_leaf".to_string()).is_ok());
        assert_eq!(imt.leaves(), vec!["new_leaf"]);
        assert_eq!(imt.nodes(), build(3, 2, &["new_leaf"]).nodes());
    }

    #[test]
    fn test_create_and_verify_proof() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = IMT::new(hash, 3, "zero".to_string(), 2, vec!["leaf1".to_string()]).unwrap();
        imt.insert("leaf2".to_string()).unwrap();

        let proof = imt.create_proof(0);
        assert!(proof.is_ok());

        let proof = proof.unwrap();
        assert_eq!(proof.path_indices, vec![0_usize; 3]);
        assert_eq!(
            proof.siblings,
            vec![vec!["leaf2"], vec!["zero,zero"], vec!["zero,zero,zero,zero"]]
        );
        assert!(imt.verify_proof(&proof));

        // A proof whose leaf was swapped must no longer reach the recorded root.
        let mut forged = proof;
        forged.leaf = "forged".to_string();
        assert!(!imt.verify_proof(&forged));
    }

    #[test]
    fn should_not_create_proof_for_nonexistent_leaf() {
        let imt = build(3, 2, &["leaf1"]);

        assert!(imt.create_proof(1).is_err());
    }

    #[test]
    fn should_not_initialize_with_too_many_leaves() {
        let hash: IMTHashFunction = simple_hash_function;
        let leaves = vec![
            "leaf1".to_string(),
            "leaf2".to_string(),
            "leaf3".to_string(),
            "leaf4".to_string(),
            "leaf5".to_string(),
        ];
        let imt = IMT::new(hash, 2, "zero".to_string(), 2, leaves);
        assert!(imt.is_err());
    }

    #[test]
    fn should_not_insert_in_full_tree() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = IMT::new(
            hash,
            1,
            "zero".to_string(),
            2,
            vec!["leaf1".to_string(), "leaf2".to_string()],
        )
        .unwrap();

        let result = imt.insert("leaf3".to_string());
        assert!(result.is_err());
        assert_eq!(imt.leaves(), vec!["leaf1", "leaf2"]);
    }

    #[test]
    fn should_not_delete_nonexistent_leaf() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = IMT::new(hash, 3, "zero".to_string(), 2, vec!["leaf1".to_string()]).unwrap();

        let result = imt.delete(1);
        assert!(result.is_err());
    }

    #[test]
    fn should_not_update_nonexistent_leaf() {
        let mut imt = build(3, 2, &["leaf1"]);

        assert!(imt.update(1, "leaf2".to_string()).is_err());
        assert_eq!(imt.leaves(), vec!["leaf1"]);
    }

    #[test]
    fn test_root() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = IMT::new(
            hash,
            2,
            "zero".to_string(),
            2,
            vec!["leaf1".to_string(), "leaf2".to_string()],
        )
        .unwrap();

        assert_eq!(imt.root(), Some("leaf1,leaf2,zero,zero".to_string()));
    }

    #[test]
    fn test_leaves() {
        let hash: IMTHashFunction = simple_hash_function;
        let imt = IMT::new(
            hash,
            2,
            "zero".to_string(),
            2,
            vec!["leaf1".to_string(), "leaf2".to_string()],
        )
        .unwrap();

        assert_eq!(imt.leaves(), vec!["leaf1".to_string(), "leaf2".to_string()]);
    }

    #[test]
    fn test_depth_and_arity() {
        let hash: IMTHashFunction = simple_hash_function;
        let imt = IMT::new(hash, 3, "zero".to_string(), 2, vec![]).unwrap();

        assert_eq!(imt.depth(), 3);
        assert_eq!(imt.arity(), 2);
    }
}