1pub struct IMT {
2 nodes: Vec<Vec<IMTNode>>,
3 zeroes: Vec<IMTNode>,
4 hash: IMTHashFunction,
5 depth: usize,
6 arity: usize,
7}
8
9pub struct IMTMerkleProof {
10 root: IMTNode,
11 leaf: IMTNode,
12 path_indices: Vec<usize>,
13 siblings: Vec<Vec<IMTNode>>,
14}
15
/// Signature of the node-combining function: receives exactly `arity` children
/// (left to right) and returns their parent node.
pub type IMTHashFunction = fn(Vec<IMTNode>) -> IMTNode;

/// A single tree node (leaf, internal node, or root), represented as a string.
pub type IMTNode = String;
18
19impl IMT {
20 pub fn new(
21 hash: IMTHashFunction,
22 depth: usize,
23 zero_value: IMTNode,
24 arity: usize,
25 leaves: Vec<IMTNode>,
26 ) -> Result<IMT, &'static str> {
27 if leaves.len() > arity.pow(depth as u32) {
28 return Err("The tree cannot contain more than arity^depth leaves");
29 }
30
31 let mut imt = IMT {
32 nodes: vec![vec![]; depth + 1],
33 zeroes: vec![],
34 hash,
35 depth,
36 arity,
37 };
38
39 let mut current_zero = zero_value;
40 for _ in 0..depth {
41 imt.zeroes.push(current_zero.clone());
42 current_zero = (imt.hash)(vec![current_zero; arity]);
43 }
44
45 imt.nodes[0] = leaves;
46
47 for level in 0..depth {
48 for index in 0..((imt.nodes[level].len() as f64 / arity as f64).ceil() as usize) {
49 let position = index * arity;
50 let children: Vec<_> = (0..arity)
51 .map(|i| {
52 imt.nodes[level]
53 .get(position + i)
54 .cloned()
55 .unwrap_or_else(|| imt.zeroes[level].clone())
56 })
57 .collect();
58
59 if let Some(next_level) = imt.nodes.get_mut(level + 1) {
60 next_level.push((imt.hash)(children));
61 }
62 }
63 }
64
65 Ok(imt)
66 }
67
68 pub fn root(&mut self) -> Option<IMTNode> {
69 self.nodes[self.depth].first().cloned()
70 }
71
72 pub fn depth(&self) -> usize {
73 self.depth
74 }
75
76 pub fn nodes(&self) -> Vec<Vec<IMTNode>> {
77 self.nodes.clone()
78 }
79
80 pub fn zeroes(&self) -> Vec<IMTNode> {
81 self.zeroes.clone()
82 }
83
84 pub fn leaves(&self) -> Vec<IMTNode> {
85 self.nodes[0].clone()
86 }
87
88 pub fn arity(&self) -> usize {
89 self.arity
90 }
91
92 pub fn index_of(&self, leaf: IMTNode) -> Option<usize> {
93 self.nodes.get(0)?.iter().position(|n| n == &leaf)
94 }
95
96 pub fn insert(&mut self, leaf: IMTNode) -> Result<(), &'static str> {
97 if self.nodes[0].len() >= self.arity.pow(self.depth as u32) {
98 return Err("The tree is full");
99 }
100
101 let index = self.nodes[0].len();
102 self.nodes[0].push(leaf);
103 self.update(index, self.nodes[0][index].clone())
104 }
105
106 pub fn update(&mut self, mut index: usize, new_leaf: IMTNode) -> Result<(), &'static str> {
107 if index >= self.nodes[0].len() {
108 return Err("The leaf does not exist in this tree");
109 }
110
111 let mut node = new_leaf;
112 self.nodes[0][index].clone_from(&node);
113
114 for level in 0..self.depth {
115 let position = index % self.arity;
116 let level_start_index = index - position;
117 let level_end_index = level_start_index + self.arity;
118
119 let children: Vec<_> = (level_start_index..level_end_index)
120 .map(|i| {
121 self.nodes[level]
122 .get(i)
123 .cloned()
124 .unwrap_or_else(|| self.zeroes[level].clone())
125 })
126 .collect();
127
128 node = (self.hash)(children);
129 index /= self.arity;
130
131 if self.nodes[level + 1].len() <= index {
132 self.nodes[level + 1].push(node.clone());
133 } else {
134 self.nodes[level + 1][index].clone_from(&node);
135 }
136 }
137
138 Ok(())
139 }
140
141 pub fn delete(&mut self, index: usize) -> Result<(), &'static str> {
142 self.update(index, self.zeroes[0].clone())
143 }
144
145 pub fn create_proof(&self, index: usize) -> Result<IMTMerkleProof, &'static str> {
146 if index >= self.nodes[0].len() {
147 return Err("The leaf does not exist in this tree");
148 }
149
150 let mut siblings = Vec::with_capacity(self.depth);
151 let mut path_indices = Vec::with_capacity(self.depth);
152 let mut current_index = index;
153
154 for level in 0..self.depth {
155 let position = current_index % self.arity;
156 let level_start_index = current_index - position;
157 let level_end_index = level_start_index + self.arity;
158
159 path_indices.push(position);
160 let mut level_siblings = Vec::new();
161
162 for i in level_start_index..level_end_index {
163 if i != current_index {
164 level_siblings.push(
165 self.nodes[level]
166 .get(i)
167 .cloned()
168 .unwrap_or_else(|| self.zeroes[level].clone()),
169 );
170 }
171 }
172
173 siblings.push(level_siblings);
174 current_index /= self.arity;
175 }
176
177 Ok(IMTMerkleProof {
178 root: self.nodes[self.depth][0].clone(),
179 leaf: self.nodes[0][index].clone(),
180 path_indices,
181 siblings,
182 })
183 }
184
185 pub fn verify_proof(&self, proof: &IMTMerkleProof) -> bool {
186 let mut node = proof.leaf.clone();
187
188 for (i, sibling) in proof.siblings.iter().enumerate() {
189 let mut children = sibling.clone();
190 children.insert(proof.path_indices[i], node);
191
192 node = (self.hash)(children);
193 }
194
195 node == proof.root
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only "hash" that joins the children with commas, so expected roots can be
    /// written out by hand.
    fn simple_hash_function(nodes: Vec<String>) -> String {
        nodes.join(",")
    }

    /// Builds a tree using `simple_hash_function` and the zero value `"zero"`.
    fn build(depth: usize, arity: usize, leaves: &[&str]) -> Result<IMT, &'static str> {
        let leaves = leaves.iter().map(|leaf| leaf.to_string()).collect();
        IMT::new(simple_hash_function, depth, "zero".to_string(), arity, leaves)
    }

    /// Converts string slices into owned nodes for comparisons.
    fn owned(values: &[&str]) -> Vec<String> {
        values.iter().map(|value| value.to_string()).collect()
    }

    /// Asserts that building a tree in one go and inserting the same leaves one by one
    /// produce identical levels.
    fn assert_bulk_equals_incremental(depth: usize, arity: usize, leaves: &[&str]) {
        let bulk = build(depth, arity, leaves).unwrap();
        let mut incremental = build(depth, arity, &[]).unwrap();
        for leaf in leaves {
            incremental.insert(leaf.to_string()).unwrap();
        }
        assert_eq!(bulk.nodes(), incremental.nodes());
    }

    #[test]
    fn test_new_imt() {
        let imt = build(3, 2, &[]).unwrap();

        assert!(imt.leaves().is_empty());
        assert_eq!(
            imt.zeroes(),
            owned(&["zero", "zero,zero", "zero,zero,zero,zero"])
        );
    }

    #[test]
    fn test_insertion() {
        let mut imt = build(3, 2, &[]).unwrap();

        assert!(imt.insert("leaf1".to_string()).is_ok());
        assert_eq!(imt.leaves(), owned(&["leaf1"]));
        assert_eq!(
            imt.root(),
            Some("leaf1,zero,zero,zero,zero,zero,zero,zero".to_string())
        );
    }

    #[test]
    fn test_insert_matches_bulk_construction() {
        assert_bulk_equals_incremental(2, 2, &["l1", "l2", "l3"]);
        assert_bulk_equals_incremental(2, 3, &["a", "b", "c", "d"]);

        let root = build(2, 3, &["a", "b", "c", "d"]).unwrap().root();
        assert_eq!(
            root,
            Some("a,b,c,d,zero,zero,zero,zero,zero".to_string())
        );
    }

    #[test]
    fn test_delete() {
        let mut imt = build(3, 2, &["leaf1"]).unwrap();

        assert!(imt.delete(0).is_ok());
        assert_eq!(imt.leaves(), owned(&["zero"]));
        assert_eq!(imt.root(), Some(vec!["zero"; 8].join(",")));
    }

    #[test]
    fn test_update() {
        let mut imt = build(3, 2, &["leaf1"]).unwrap();

        assert!(imt.update(0, "new_leaf".to_string()).is_ok());
        assert_eq!(imt.leaves(), owned(&["new_leaf"]));
        assert_eq!(
            imt.root(),
            Some("new_leaf,zero,zero,zero,zero,zero,zero,zero".to_string())
        );
    }

    #[test]
    fn test_create_and_verify_proof() {
        let mut imt = build(3, 2, &["leaf1"]).unwrap();
        imt.insert("leaf2".to_string()).unwrap();

        let proof = imt.create_proof(0).unwrap();
        assert_eq!(proof.leaf, "leaf1");
        assert_eq!(proof.path_indices, vec![0, 0, 0]);
        assert_eq!(
            proof.siblings,
            vec![
                owned(&["leaf2"]),
                owned(&["zero,zero"]),
                owned(&["zero,zero,zero,zero"]),
            ]
        );
        assert_eq!(proof.root, "leaf1,leaf2,zero,zero,zero,zero,zero,zero");
        assert_eq!(imt.root(), Some(proof.root.clone()));
        assert!(imt.verify_proof(&proof));

        let proof = imt.create_proof(1).unwrap();
        assert_eq!(proof.path_indices, vec![1, 0, 0]);
        assert_eq!(proof.siblings[0], owned(&["leaf1"]));
        assert!(imt.verify_proof(&proof));
    }

    #[test]
    fn should_not_verify_tampered_proof() {
        let imt = build(3, 2, &["leaf1", "leaf2"]).unwrap();

        let mut proof = imt.create_proof(0).unwrap();
        proof.leaf = "forged".to_string();
        assert!(!imt.verify_proof(&proof));
    }

    #[test]
    fn should_not_create_proof_for_nonexistent_leaf() {
        let imt = build(3, 2, &["leaf1"]).unwrap();

        assert!(imt.create_proof(1).is_err());
    }

    #[test]
    fn should_not_initialize_with_too_many_leaves() {
        let imt = build(2, 2, &["leaf1", "leaf2", "leaf3", "leaf4", "leaf5"]);
        assert!(imt.is_err());
    }

    #[test]
    fn should_not_insert_in_full_tree() {
        let mut imt = build(1, 2, &["leaf1", "leaf2"]).unwrap();

        let result = imt.insert("leaf3".to_string());
        assert!(result.is_err());
        assert_eq!(imt.leaves(), owned(&["leaf1", "leaf2"]));
    }

    #[test]
    fn should_not_delete_nonexistent_leaf() {
        let mut imt = build(3, 2, &["leaf1"]).unwrap();

        let result = imt.delete(1);
        assert!(result.is_err());
    }

    #[test]
    fn should_not_update_nonexistent_leaf() {
        let mut imt = build(3, 2, &["leaf1"]).unwrap();

        assert!(imt.update(1, "leaf2".to_string()).is_err());
        assert_eq!(imt.leaves(), owned(&["leaf1"]));
    }

    #[test]
    fn test_root() {
        // Called on a temporary so this compiles whether `root` borrows mutably or not.
        let root = build(2, 2, &["leaf1", "leaf2"]).unwrap().root();

        assert_eq!(root, Some("leaf1,leaf2,zero,zero".to_string()));
    }

    #[test]
    fn test_leaves() {
        let imt = build(2, 2, &["leaf1", "leaf2"]).unwrap();

        assert_eq!(imt.leaves(), owned(&["leaf1", "leaf2"]));
    }

    #[test]
    fn test_depth_and_arity() {
        let imt = build(3, 2, &[]).unwrap();

        assert_eq!(imt.depth(), 3);
        assert_eq!(imt.arity(), 2);
    }

    #[test]
    fn test_index_of() {
        let imt = build(2, 2, &["leaf1", "leaf2"]).unwrap();

        assert_eq!(imt.index_of("leaf1".to_string()), Some(0));
        assert_eq!(imt.index_of("leaf2".to_string()), Some(1));
        assert_eq!(imt.index_of("leaf3".to_string()), None);
    }
}