1pub struct IMT {
2 nodes: Vec<Vec<IMTNode>>,
3 zeroes: Vec<IMTNode>,
4 hash: IMTHashFunction,
5 depth: usize,
6 arity: usize,
7}
8
9pub struct IMTMerkleProof {
10 root: IMTNode,
11 leaf: IMTNode,
12 path_indices: Vec<usize>,
13 siblings: Vec<Vec<IMTNode>>,
14}
15
/// Function that folds the children of a node, supplied in index order, into their parent.
pub type IMTHashFunction = fn(children: Vec<IMTNode>) -> IMTNode;

/// Value of any node in the tree (leaf, internal node, or root), stored as a `String`.
pub type IMTNode = std::string::String;

19impl IMT {
20 pub fn new(
21 hash: IMTHashFunction,
22 depth: usize,
23 zero_value: IMTNode,
24 arity: usize,
25 leaves: Vec<IMTNode>,
26 ) -> Result<IMT, &'static str> {
27 if leaves.len() > arity.pow(depth as u32) {
28 return Err("The tree cannot contain more than arity^depth leaves");
29 }
30
31 let mut imt = IMT {
32 nodes: vec![vec![]; depth + 1],
33 zeroes: vec![],
34 hash,
35 depth,
36 arity,
37 };
38
39 let mut current_zero = zero_value;
40 for _ in 0..depth {
41 imt.zeroes.push(current_zero.clone());
42 current_zero = (imt.hash)(vec![current_zero; arity]);
43 }
44
45 imt.nodes[0] = leaves;
46
47 for level in 0..depth {
48 for index in 0..((imt.nodes[level].len() as f64 / arity as f64).ceil() as usize) {
49 let position = index * arity;
50 let children: Vec<_> = (0..arity)
51 .map(|i| {
52 imt.nodes[level]
53 .get(position + i)
54 .cloned()
55 .unwrap_or_else(|| imt.zeroes[level].clone())
56 })
57 .collect();
58
59 if let Some(next_level) = imt.nodes.get_mut(level + 1) {
60 next_level.push((imt.hash)(children));
61 }
62 }
63 }
64
65 Ok(imt)
66 }
67
68 pub fn root(&mut self) -> Option<IMTNode> {
69 self.nodes[self.depth].first().cloned()
70 }
71
72 pub fn depth(&self) -> usize {
73 self.depth
74 }
75
76 pub fn nodes(&self) -> Vec<Vec<IMTNode>> {
77 self.nodes.clone()
78 }
79
80 pub fn zeroes(&self) -> Vec<IMTNode> {
81 self.zeroes.clone()
82 }
83
84 pub fn leaves(&self) -> Vec<IMTNode> {
85 self.nodes[0].clone()
86 }
87
88 pub fn arity(&self) -> usize {
89 self.arity
90 }
91
92 pub fn insert(&mut self, leaf: IMTNode) -> Result<(), &'static str> {
93 if self.nodes[0].len() >= self.arity.pow(self.depth as u32) {
94 return Err("The tree is full");
95 }
96
97 let index = self.nodes[0].len();
98 self.nodes[0].push(leaf);
99 self.update(index, self.nodes[0][index].clone())
100 }
101
102 pub fn update(&mut self, mut index: usize, new_leaf: IMTNode) -> Result<(), &'static str> {
103 if index >= self.nodes[0].len() {
104 return Err("The leaf does not exist in this tree");
105 }
106
107 let mut node = new_leaf;
108 self.nodes[0][index].clone_from(&node);
109
110 for level in 0..self.depth {
111 let position = index % self.arity;
112 let level_start_index = index - position;
113 let level_end_index = level_start_index + self.arity;
114
115 let children: Vec<_> = (level_start_index..level_end_index)
116 .map(|i| {
117 self.nodes[level]
118 .get(i)
119 .cloned()
120 .unwrap_or_else(|| self.zeroes[level].clone())
121 })
122 .collect();
123
124 node = (self.hash)(children);
125 index /= self.arity;
126
127 if self.nodes[level + 1].len() <= index {
128 self.nodes[level + 1].push(node.clone());
129 } else {
130 self.nodes[level + 1][index].clone_from(&node);
131 }
132 }
133
134 Ok(())
135 }
136
137 pub fn delete(&mut self, index: usize) -> Result<(), &'static str> {
138 self.update(index, self.zeroes[0].clone())
139 }
140
141 pub fn create_proof(&self, index: usize) -> Result<IMTMerkleProof, &'static str> {
142 if index >= self.nodes[0].len() {
143 return Err("The leaf does not exist in this tree");
144 }
145
146 let mut siblings = Vec::with_capacity(self.depth);
147 let mut path_indices = Vec::with_capacity(self.depth);
148 let mut current_index = index;
149
150 for level in 0..self.depth {
151 let position = current_index % self.arity;
152 let level_start_index = current_index - position;
153 let level_end_index = level_start_index + self.arity;
154
155 path_indices.push(position);
156 let mut level_siblings = Vec::new();
157
158 for i in level_start_index..level_end_index {
159 if i != current_index {
160 level_siblings.push(
161 self.nodes[level]
162 .get(i)
163 .cloned()
164 .unwrap_or_else(|| self.zeroes[level].clone()),
165 );
166 }
167 }
168
169 siblings.push(level_siblings);
170 current_index /= self.arity;
171 }
172
173 Ok(IMTMerkleProof {
174 root: self.nodes[self.depth][0].clone(),
175 leaf: self.nodes[0][index].clone(),
176 path_indices,
177 siblings,
178 })
179 }
180
181 pub fn verify_proof(&self, proof: &IMTMerkleProof) -> bool {
182 let mut node = proof.leaf.clone();
183
184 for (i, sibling) in proof.siblings.iter().enumerate() {
185 let mut children = sibling.clone();
186 children.insert(proof.path_indices[i], node);
187
188 node = (self.hash)(children);
189 }
190
191 node == proof.root
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic stand-in hash: joins the children with commas so that expected node
    /// values can be written out by hand.
    fn simple_hash_function(nodes: Vec<String>) -> String {
        nodes.join(",")
    }

    #[test]
    fn test_new_imt() {
        let hash: IMTHashFunction = simple_hash_function;
        let imt = IMT::new(hash, 3, "zero".to_string(), 2, vec![]);

        assert!(imt.is_ok());
        assert!(imt.unwrap().leaves().is_empty());
    }

    #[test]
    fn test_zeroes() {
        let hash: IMTHashFunction = simple_hash_function;
        let imt = IMT::new(hash, 3, "zero".to_string(), 2, vec![]).unwrap();

        assert_eq!(
            imt.zeroes(),
            vec![
                "zero".to_string(),
                "zero,zero".to_string(),
                "zero,zero,zero,zero".to_string(),
            ]
        );
    }

    #[test]
    fn test_insertion() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = IMT::new(hash, 3, "zero".to_string(), 2, vec![]).unwrap();

        assert!(imt.insert("leaf1".to_string()).is_ok());
        assert_eq!(imt.leaves(), vec!["leaf1".to_string()]);
        assert_eq!(
            imt.root(),
            Some("leaf1,zero,zero,zero,zero,zero,zero,zero".to_string())
        );
    }

    #[test]
    fn batch_and_incremental_construction_agree() {
        let hash: IMTHashFunction = simple_hash_function;
        let leaves = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let batch = IMT::new(hash, 3, "zero".to_string(), 2, leaves.clone()).unwrap();
        let mut incremental = IMT::new(hash, 3, "zero".to_string(), 2, vec![]).unwrap();
        for leaf in leaves {
            incremental.insert(leaf).unwrap();
        }

        assert_eq!(batch.nodes(), incremental.nodes());
    }

    #[test]
    fn test_delete() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = IMT::new(hash, 3, "zero".to_string(), 2, vec!["leaf1".to_string()]).unwrap();

        assert!(imt.delete(0).is_ok());
        // Deleting keeps the slot but resets it to the zero value.
        assert_eq!(imt.leaves(), vec!["zero".to_string()]);
        assert_eq!(
            imt.root(),
            Some("zero,zero,zero,zero,zero,zero,zero,zero".to_string())
        );
    }

    #[test]
    fn test_update() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = IMT::new(hash, 3, "zero".to_string(), 2, vec!["leaf1".to_string()]).unwrap();

        assert!(imt.update(0, "new_leaf".to_string()).is_ok());
        assert_eq!(imt.leaves(), vec!["new_leaf".to_string()]);
        assert_eq!(
            imt.root(),
            Some("new_leaf,zero,zero,zero,zero,zero,zero,zero".to_string())
        );
    }

    #[test]
    fn test_create_and_verify_proof() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = IMT::new(hash, 3, "zero".to_string(), 2, vec!["leaf1".to_string()]).unwrap();
        imt.insert("leaf2".to_string()).unwrap();

        let proof = imt.create_proof(0);
        assert!(proof.is_ok());

        let proof = proof.unwrap();
        assert_eq!(proof.leaf, "leaf1");
        assert_eq!(proof.root, "leaf1,leaf2,zero,zero,zero,zero,zero,zero");
        assert_eq!(proof.path_indices, vec![0, 0, 0]);
        assert_eq!(
            proof.siblings,
            vec![
                vec!["leaf2".to_string()],
                vec!["zero,zero".to_string()],
                vec!["zero,zero,zero,zero".to_string()],
            ]
        );
        assert!(imt.verify_proof(&proof));

        // Swapping the leaf must break the hash chain to the recorded root.
        let forged = IMTMerkleProof {
            leaf: "leaf3".to_string(),
            ..proof
        };
        assert!(!imt.verify_proof(&forged));

        let second = imt.create_proof(1).unwrap();
        assert_eq!(second.path_indices, vec![1, 0, 0]);
        assert_eq!(second.siblings[0], vec!["leaf1".to_string()]);
        assert!(imt.verify_proof(&second));
    }

    #[test]
    fn should_not_create_proof_for_nonexistent_leaf() {
        let hash: IMTHashFunction = simple_hash_function;
        let imt = IMT::new(hash, 3, "zero".to_string(), 2, vec!["leaf1".to_string()]).unwrap();

        assert!(imt.create_proof(1).is_err());
    }

    #[test]
    fn should_not_initialize_with_too_many_leaves() {
        let hash: IMTHashFunction = simple_hash_function;
        let leaves = vec![
            "leaf1".to_string(),
            "leaf2".to_string(),
            "leaf3".to_string(),
            "leaf4".to_string(),
            "leaf5".to_string(),
        ];
        let imt = IMT::new(hash, 2, "zero".to_string(), 2, leaves);
        assert!(imt.is_err());
    }

    #[test]
    fn should_not_insert_in_full_tree() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = IMT::new(
            hash,
            1,
            "zero".to_string(),
            2,
            vec!["leaf1".to_string(), "leaf2".to_string()],
        )
        .unwrap();

        let result = imt.insert("leaf3".to_string());
        assert!(result.is_err());
        assert_eq!(imt.leaves(), vec!["leaf1".to_string(), "leaf2".to_string()]);
    }

    #[test]
    fn should_not_delete_nonexistent_leaf() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = IMT::new(hash, 3, "zero".to_string(), 2, vec!["leaf1".to_string()]).unwrap();

        let result = imt.delete(1);
        assert!(result.is_err());
    }

    #[test]
    fn should_not_update_nonexistent_leaf() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = IMT::new(hash, 3, "zero".to_string(), 2, vec!["leaf1".to_string()]).unwrap();

        assert!(imt.update(1, "other".to_string()).is_err());
        assert_eq!(imt.leaves(), vec!["leaf1".to_string()]);
    }

    #[test]
    fn test_root() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = IMT::new(
            hash,
            2,
            "zero".to_string(),
            2,
            vec!["leaf1".to_string(), "leaf2".to_string()],
        )
        .unwrap();

        assert_eq!(imt.root(), Some("leaf1,leaf2,zero,zero".to_string()));

        imt.insert("leaf3".to_string()).unwrap();
        assert_eq!(imt.root(), Some("leaf1,leaf2,leaf3,zero".to_string()));
    }

    #[test]
    fn test_leaves() {
        let hash: IMTHashFunction = simple_hash_function;
        let imt = IMT::new(
            hash,
            2,
            "zero".to_string(),
            2,
            vec!["leaf1".to_string(), "leaf2".to_string()],
        )
        .unwrap();

        assert_eq!(imt.leaves(), vec!["leaf1".to_string(), "leaf2".to_string()]);
    }

    #[test]
    fn test_depth_and_arity() {
        let hash: IMTHashFunction = simple_hash_function;
        let imt = IMT::new(hash, 3, "zero".to_string(), 2, vec![]).unwrap();

        assert_eq!(imt.depth(), 3);
        assert_eq!(imt.arity(), 2);
    }
}