zk_kit_imt/
imt.rs

1pub struct IMT {
2    nodes: Vec<Vec<IMTNode>>,
3    zeroes: Vec<IMTNode>,
4    hash: IMTHashFunction,
5    depth: usize,
6    arity: usize,
7}
8
9pub struct IMTMerkleProof {
10    root: IMTNode,
11    leaf: IMTNode,
12    path_indices: Vec<usize>,
13    siblings: Vec<Vec<IMTNode>>,
14}
15
/// A single tree node (leaf, internal node or root), represented as a string.
pub type IMTNode = String;

/// Hash used by the tree: it receives the children of a node, in order, and returns the node.
pub type IMTHashFunction = fn(children: Vec<IMTNode>) -> IMTNode;

19impl IMT {
20    pub fn new(
21        hash: IMTHashFunction,
22        depth: usize,
23        zero_value: IMTNode,
24        arity: usize,
25        leaves: Vec<IMTNode>,
26    ) -> Result<IMT, &'static str> {
27        if leaves.len() > arity.pow(depth as u32) {
28            return Err("The tree cannot contain more than arity^depth leaves");
29        }
30
31        let mut imt = IMT {
32            nodes: vec![vec![]; depth + 1],
33            zeroes: vec![],
34            hash,
35            depth,
36            arity,
37        };
38
39        let mut current_zero = zero_value;
40        for _ in 0..depth {
41            imt.zeroes.push(current_zero.clone());
42            current_zero = (imt.hash)(vec![current_zero; arity]);
43        }
44
45        imt.nodes[0] = leaves;
46
47        for level in 0..depth {
48            for index in 0..((imt.nodes[level].len() as f64 / arity as f64).ceil() as usize) {
49                let position = index * arity;
50                let children: Vec<_> = (0..arity)
51                    .map(|i| {
52                        imt.nodes[level]
53                            .get(position + i)
54                            .cloned()
55                            .unwrap_or_else(|| imt.zeroes[level].clone())
56                    })
57                    .collect();
58
59                if let Some(next_level) = imt.nodes.get_mut(level + 1) {
60                    next_level.push((imt.hash)(children));
61                }
62            }
63        }
64
65        Ok(imt)
66    }
67
68    pub fn root(&mut self) -> Option<IMTNode> {
69        self.nodes[self.depth].first().cloned()
70    }
71
72    pub fn depth(&self) -> usize {
73        self.depth
74    }
75
76    pub fn nodes(&self) -> Vec<Vec<IMTNode>> {
77        self.nodes.clone()
78    }
79
80    pub fn zeroes(&self) -> Vec<IMTNode> {
81        self.zeroes.clone()
82    }
83
84    pub fn leaves(&self) -> Vec<IMTNode> {
85        self.nodes[0].clone()
86    }
87
88    pub fn arity(&self) -> usize {
89        self.arity
90    }
91
92    pub fn index_of(&self, leaf: IMTNode) -> Option<usize> {
93        self.nodes.get(0)?.iter().position(|n| n == &leaf)
94    }
95
96    pub fn insert(&mut self, leaf: IMTNode) -> Result<(), &'static str> {
97        if self.nodes[0].len() >= self.arity.pow(self.depth as u32) {
98            return Err("The tree is full");
99        }
100
101        let index = self.nodes[0].len();
102        self.nodes[0].push(leaf);
103        self.update(index, self.nodes[0][index].clone())
104    }
105
106    pub fn update(&mut self, mut index: usize, new_leaf: IMTNode) -> Result<(), &'static str> {
107        if index >= self.nodes[0].len() {
108            return Err("The leaf does not exist in this tree");
109        }
110
111        let mut node = new_leaf;
112        self.nodes[0][index].clone_from(&node);
113
114        for level in 0..self.depth {
115            let position = index % self.arity;
116            let level_start_index = index - position;
117            let level_end_index = level_start_index + self.arity;
118
119            let children: Vec<_> = (level_start_index..level_end_index)
120                .map(|i| {
121                    self.nodes[level]
122                        .get(i)
123                        .cloned()
124                        .unwrap_or_else(|| self.zeroes[level].clone())
125                })
126                .collect();
127
128            node = (self.hash)(children);
129            index /= self.arity;
130
131            if self.nodes[level + 1].len() <= index {
132                self.nodes[level + 1].push(node.clone());
133            } else {
134                self.nodes[level + 1][index].clone_from(&node);
135            }
136        }
137
138        Ok(())
139    }
140
141    pub fn delete(&mut self, index: usize) -> Result<(), &'static str> {
142        self.update(index, self.zeroes[0].clone())
143    }
144
145    pub fn create_proof(&self, index: usize) -> Result<IMTMerkleProof, &'static str> {
146        if index >= self.nodes[0].len() {
147            return Err("The leaf does not exist in this tree");
148        }
149
150        let mut siblings = Vec::with_capacity(self.depth);
151        let mut path_indices = Vec::with_capacity(self.depth);
152        let mut current_index = index;
153
154        for level in 0..self.depth {
155            let position = current_index % self.arity;
156            let level_start_index = current_index - position;
157            let level_end_index = level_start_index + self.arity;
158
159            path_indices.push(position);
160            let mut level_siblings = Vec::new();
161
162            for i in level_start_index..level_end_index {
163                if i != current_index {
164                    level_siblings.push(
165                        self.nodes[level]
166                            .get(i)
167                            .cloned()
168                            .unwrap_or_else(|| self.zeroes[level].clone()),
169                    );
170                }
171            }
172
173            siblings.push(level_siblings);
174            current_index /= self.arity;
175        }
176
177        Ok(IMTMerkleProof {
178            root: self.nodes[self.depth][0].clone(),
179            leaf: self.nodes[0][index].clone(),
180            path_indices,
181            siblings,
182        })
183    }
184
185    pub fn verify_proof(&self, proof: &IMTMerkleProof) -> bool {
186        let mut node = proof.leaf.clone();
187
188        for (i, sibling) in proof.siblings.iter().enumerate() {
189            let mut children = sibling.clone();
190            children.insert(proof.path_indices[i], node);
191
192            node = (self.hash)(children);
193        }
194
195        node == proof.root
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Test hash: joins the children with commas, so every node spells out its subtree.
    fn simple_hash_function(nodes: Vec<String>) -> String {
        nodes.join(",")
    }

    /// Builds a tree hashed with `simple_hash_function` whose zero value is `"zero"`.
    fn tree(depth: usize, arity: usize, leaves: &[&str]) -> IMT {
        let leaves = leaves.iter().map(|leaf| leaf.to_string()).collect();
        IMT::new(simple_hash_function, depth, "zero".to_string(), arity, leaves).unwrap()
    }

    /// `"zero"` repeated `n` times and comma-joined: the value of an all-zero subtree of `n` leaves.
    fn zeros(n: usize) -> String {
        vec!["zero"; n].join(",")
    }

    #[test]
    fn test_new_imt() {
        let hash: IMTHashFunction = simple_hash_function;
        let imt = IMT::new(hash, 3, "zero".to_string(), 2, vec![]);

        assert!(imt.is_ok());
        let imt = imt.unwrap();
        assert!(imt.leaves().is_empty());
        assert_eq!(imt.zeroes(), vec![zeros(1), zeros(2), zeros(4)]);
    }

    #[test]
    fn test_insertion() {
        let mut imt = tree(3, 2, &[]);

        assert!(imt.insert("leaf1".to_string()).is_ok());
        assert_eq!(imt.leaves(), vec!["leaf1".to_string()]);
        assert_eq!(imt.root(), Some(format!("leaf1,{}", zeros(7))));
    }

    #[test]
    fn test_incremental_insertion_matches_bulk_construction() {
        let mut incremental = tree(3, 2, &[]);
        for leaf in &["leaf1", "leaf2", "leaf3"] {
            incremental.insert(leaf.to_string()).unwrap();
        }

        let bulk = tree(3, 2, &["leaf1", "leaf2", "leaf3"]);
        assert_eq!(incremental.nodes(), bulk.nodes());
    }

    #[test]
    fn test_delete() {
        let mut imt = tree(3, 2, &["leaf1"]);

        assert!(imt.delete(0).is_ok());
        assert_eq!(imt.leaves(), vec![zeros(1)]);
        assert_eq!(imt.root(), Some(zeros(8)));
    }

    #[test]
    fn test_update() {
        let mut imt = tree(3, 2, &["leaf1"]);

        assert!(imt.update(0, "new_leaf".to_string()).is_ok());
        assert_eq!(imt.nodes(), tree(3, 2, &["new_leaf"]).nodes());
    }

    #[test]
    fn should_not_update_nonexistent_leaf() {
        let mut imt = tree(3, 2, &["leaf1"]);

        assert!(imt.update(1, "leaf2".to_string()).is_err());
        assert_eq!(imt.leaves(), vec!["leaf1".to_string()]);
    }

    #[test]
    fn test_create_and_verify_proof() {
        let mut imt = tree(3, 2, &["leaf1"]);
        imt.insert("leaf2".to_string()).unwrap();

        let proof = imt.create_proof(0);
        assert!(proof.is_ok());

        let mut proof = proof.unwrap();
        assert_eq!(proof.path_indices, vec![0, 0, 0]);
        assert_eq!(
            proof.siblings,
            vec![vec!["leaf2".to_string()], vec![zeros(2)], vec![zeros(4)]]
        );
        assert!(imt.verify_proof(&proof));

        // A forged leaf must not verify against the same root.
        proof.leaf = "forged".to_string();
        assert!(!imt.verify_proof(&proof));
    }

    #[test]
    fn test_arity_three_proof() {
        let imt = tree(2, 3, &["a", "b", "c", "d"]);

        let proof = imt.create_proof(1).unwrap();
        assert_eq!(proof.path_indices, vec![1, 0]);
        assert_eq!(
            proof.siblings,
            vec![
                vec!["a".to_string(), "c".to_string()],
                vec!["d,zero,zero".to_string(), "zero,zero,zero".to_string()],
            ]
        );
        assert_eq!(proof.root, "a,b,c,d,zero,zero,zero,zero,zero".to_string());
        assert!(imt.verify_proof(&proof));
    }

    #[test]
    fn should_not_create_proof_for_nonexistent_leaf() {
        let imt = tree(3, 2, &["leaf1"]);

        assert!(imt.create_proof(1).is_err());
    }

    #[test]
    fn should_not_initialize_with_too_many_leaves() {
        let hash: IMTHashFunction = simple_hash_function;
        let leaves = vec![
            "leaf1".to_string(),
            "leaf2".to_string(),
            "leaf3".to_string(),
            "leaf4".to_string(),
            "leaf5".to_string(),
        ];
        let imt = IMT::new(hash, 2, "zero".to_string(), 2, leaves);
        assert!(imt.is_err());
    }

    #[test]
    fn should_not_insert_in_full_tree() {
        let mut imt = tree(1, 2, &["leaf1", "leaf2"]);

        let result = imt.insert("leaf3".to_string());
        assert!(result.is_err());
        assert_eq!(imt.leaves().len(), 2);
    }

    #[test]
    fn should_not_delete_nonexistent_leaf() {
        let mut imt = tree(3, 2, &["leaf1"]);

        let result = imt.delete(1);
        assert!(result.is_err());
    }

    #[test]
    fn test_root() {
        let mut imt = tree(2, 2, &["leaf1", "leaf2"]);

        assert_eq!(imt.root(), Some("leaf1,leaf2,zero,zero".to_string()));
    }

    #[test]
    fn test_leaves() {
        let imt = tree(2, 2, &["leaf1", "leaf2"]);

        assert_eq!(imt.leaves(), vec!["leaf1".to_string(), "leaf2".to_string()]);
    }

    #[test]
    fn test_depth_and_arity() {
        let imt = tree(3, 2, &[]);

        assert_eq!(imt.depth(), 3);
        assert_eq!(imt.arity(), 2);
    }

    #[test]
    fn test_index_of() {
        let imt = tree(2, 2, &["leaf1", "leaf2"]);

        assert_eq!(imt.index_of("leaf1".to_string()), Some(0));
        assert_eq!(imt.index_of("leaf2".to_string()), Some(1));
        assert_eq!(imt.index_of("leaf3".to_string()), None);
    }
}