zk_kit_smt/
smt.rs

1use std::{collections::HashMap, str::FromStr};
2
3use num_bigint::BigInt;
4
5use crate::utils::{
6    get_first_common_elements, get_index_of_last_non_zero_element, is_hexadecimal, key_to_path,
7};
8
9use std::fmt;
10
/// Errors reported by the sparse Merkle tree operations in this module.
#[derive(Debug, PartialEq)]
pub enum SMTError {
    /// The key is already stored in the tree; carries the key as text.
    KeyAlreadyExist(String),
    /// The key is not stored in the tree; carries the key as text.
    KeyDoesNotExist(String),
    /// A parameter has the wrong type; carries the parameter and the expected type description.
    InvalidParameterType(String, String),
    /// A sibling index does not point at an existing sibling.
    InvalidSiblingIndex,
}

impl fmt::Display for SMTError {
    /// Writes the human-readable message for the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // The only variant without a payload needs no formatting machinery.
            Self::InvalidSiblingIndex => f.write_str("Invalid sibling index"),
            Self::InvalidParameterType(parameter, expected) => {
                write!(f, "Parameter {} must be a {}", parameter, expected)
            }
            Self::KeyDoesNotExist(key) => write!(f, "Key {} does not exist", key),
            Self::KeyAlreadyExist(key) => write!(f, "Key {} already exists", key),
        }
    }
}

impl std::error::Error for SMTError {}
33
34#[derive(Debug, Clone, PartialEq, Eq, Hash)]
35pub enum Node {
36    Str(String),
37    BigInt(BigInt),
38}
39
40impl fmt::Display for Node {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        match self {
43            Node::Str(s) => write!(f, "{}", s),
44            Node::BigInt(n) => write!(f, "{}", n),
45        }
46    }
47}
48
49impl FromStr for Node {
50    type Err = SMTError;
51
52    fn from_str(s: &str) -> Result<Self, Self::Err> {
53        if let Ok(bigint) = s.parse::<BigInt>() {
54            Ok(Node::BigInt(bigint))
55        } else if is_hexadecimal(s) {
56            Ok(Node::Str(s.to_string()))
57        } else {
58            Err(SMTError::InvalidParameterType(
59                s.to_string(),
60                "BigInt or hexadecimal string".to_string(),
61            ))
62        }
63    }
64}
65
/// Key of a tree entry.
pub type Key = Node;
/// Value of a tree entry.
pub type Value = Node;
/// Marker stored as the third element of an entry; a node with a third element is treated as
/// an entry rather than as an internal node.
pub type EntryMark = Node;

/// A tree entry as a `(key, value, entry mark)` tuple.
pub type Entry = (Key, Value, EntryMark);
/// The nodes stored under a hash: either two children or the three elements of an entry.
pub type ChildNodes = Vec<Node>;
/// Sibling nodes collected along a path, ordered from the root downwards.
pub type Siblings = Vec<Node>;

/// Hash function used by the tree: it receives the child nodes and returns their hash node.
pub type HashFunction = fn(ChildNodes) -> Node;
75
76pub struct EntryResponse {
77    pub entry: Vec<Node>,
78    pub matching_entry: Option<Vec<Node>>,
79    pub siblings: Siblings,
80}
81
82#[allow(dead_code)]
83pub struct MerkleProof {
84    entry_response: EntryResponse,
85    root: Node,
86    membership: bool,
87}
88
/// Sparse Merkle tree storing its nodes in a hash map keyed by node hash.
#[allow(dead_code)]
pub struct SMT {
    /// Hash function applied to child nodes and to entries.
    hash: HashFunction,
    /// Whether nodes are represented as big integers (`true`) or strings (`false`).
    big_numbers: bool,
    /// Node that stands for an empty subtree.
    zero_node: Node,
    /// Marker appended to every entry as its third element.
    entry_mark: Node,
    /// Every stored node: hash -> two children, or hash -> `[key, value, entry mark]`.
    nodes: HashMap<Node, Vec<Node>>,
    /// Current root of the tree; equal to `zero_node` while the tree is empty.
    root: Node,
}
98
99impl SMT {
100    /// Initializes a new instance of the Sparse Merkle Tree (SMT).
101    ///
102    /// # Arguments
103    ///
104    /// * `hash` - The hash function used to hash the child nodes.
105    /// * `big_numbers` - A flag indicating whether the SMT supports big numbers or not.
106    ///
107    /// # Returns
108    ///
109    /// A new instance of the SMT.
110    pub fn new(hash: HashFunction, big_numbers: bool) -> Self {
111        let zero_node;
112        let entry_mark;
113
114        if big_numbers {
115            zero_node = Node::BigInt(BigInt::from(0));
116            entry_mark = Node::BigInt(BigInt::from(1));
117        } else {
118            zero_node = Node::Str("0".to_string());
119            entry_mark = Node::Str("1".to_string());
120        }
121
122        SMT {
123            hash,
124            big_numbers,
125            zero_node: zero_node.clone(),
126            entry_mark,
127            nodes: HashMap::new(),
128            root: zero_node,
129        }
130    }
131
132    /// Retrieves the value associated with the given key from the SMT.
133    ///
134    /// # Arguments
135    ///
136    /// * `key` - The key to retrieve the value for.
137    ///
138    /// # Returns
139    ///
140    /// An `Option` containing the value associated with the key, or `None` if the key does not exist.
141    pub fn get(&self, key: Key) -> Option<Value> {
142        let key = key.to_string().parse::<Node>().unwrap();
143
144        let EntryResponse { entry, .. } = self.retrieve_entry(key);
145
146        entry.get(1).cloned()
147    }
148
149    /// Adds a new key-value pair to the SMT.
150    ///
151    /// It retrieves a matching entry or a zero node with a top-down approach and then it updates
152    /// all the hashes of the nodes in the path of the new entry with a bottom up approach.
153    ///
154    /// # Arguments
155    ///
156    /// * `key` - The key to add.
157    /// * `value` - The value associated with the key.
158    ///
159    /// # Returns
160    ///
161    /// An `Result` indicating whether the operation was successful or not.
162    pub fn add(&mut self, key: Key, value: Value) -> Result<(), SMTError> {
163        let key = key.to_string().parse::<Node>().unwrap();
164        let value = value.to_string().parse::<Node>().unwrap();
165
166        let EntryResponse {
167            entry,
168            matching_entry,
169            mut siblings,
170        } = self.retrieve_entry(key.clone());
171
172        if entry.get(1).is_some() {
173            return Err(SMTError::KeyAlreadyExist(key.to_string()));
174        }
175
176        let path = key_to_path(&key.to_string());
177        // If there is a matching entry, its node is saved in the `node` variable, otherwise the
178        // `zero_node` is saved. This node is used below as the first node (starting from the
179        // bottom of the tree) to obtain the new nodes up to the root.
180        let node = if let Some(ref matching_entry) = matching_entry {
181            (self.hash)(matching_entry.clone())
182        } else {
183            self.zero_node.clone()
184        };
185
186        // If there are siblings, the old nodes are deleted and will be re-created below with new hashes.
187        if !siblings.is_empty() {
188            self.delete_old_nodes(node.clone(), &path, &siblings)
189        }
190
191        // If there is a matching entry, further N zero siblings are added in the `siblings` vector,
192        // followed by the matching node itself. N is the number of the first matching bits of the paths.
193        // This is helpful in the non-membership proof verification as explained in the function below.
194        if let Some(matching_entry) = matching_entry {
195            let matching_path = key_to_path(&matching_entry[0].to_string());
196            let mut i = siblings.len();
197
198            while matching_path[i] == path[i] {
199                siblings.push(self.zero_node.clone());
200                i += 1;
201            }
202
203            siblings.push(node.clone());
204        }
205
206        // Adds the new entry and re-creates the nodes of the path with the new hashes with a bottom
207        // up approach. The `add_new_nodes` function returns the new root of the tree.
208        let new_node = (self.hash)(vec![key.clone(), value.clone(), self.entry_mark.clone()]);
209
210        self.nodes
211            .insert(new_node.clone(), vec![key, value, self.entry_mark.clone()]);
212        self.root = self
213            .add_new_nodes(new_node, &path, &siblings, None)
214            .unwrap();
215
216        Ok(())
217    }
218
219    /// Updates the value associated with the given key in the SMT.
220    ///
221    /// Also in this case, all the hashes of the nodes in the path of the updated entry are updated
222    /// with a bottom up approach.
223    ///
224    /// # Arguments
225    ///
226    /// * `key` - The key to update the value for.
227    /// * `value` - The new value associated with the key.
228    ///
229    /// # Returns
230    ///
231    /// An `Result` indicating whether the operation was successful or not.
232    pub fn update(&mut self, key: Key, value: Value) -> Result<(), SMTError> {
233        let key = key.to_string().parse::<Node>().unwrap();
234        let value = value.to_string().parse::<Node>().unwrap();
235
236        let EntryResponse {
237            entry, siblings, ..
238        } = self.retrieve_entry(key.clone());
239
240        if entry.get(1).is_none() {
241            return Err(SMTError::KeyDoesNotExist(key.to_string()));
242        }
243
244        let path = key_to_path(&key.to_string());
245
246        // Deletes the old nodes and re-creates them with the new hashes.
247        let old_node = (self.hash)(entry.clone());
248        self.nodes.remove(&old_node);
249        self.delete_old_nodes(old_node.clone(), &path, &siblings);
250
251        let new_node = (self.hash)(vec![key.clone(), value.clone(), self.entry_mark.clone()]);
252        self.nodes
253            .insert(new_node.clone(), vec![key, value, self.entry_mark.clone()]);
254        self.root = self
255            .add_new_nodes(new_node, &path, &siblings, None)
256            .unwrap();
257
258        Ok(())
259    }
260
261    /// Deletes the key-value pair associated with the given key from the SMT.
262    ///
263    /// Also in this case, all the hashes of the nodes in the path of the deleted entry are updated
264    /// with a bottom up approach.
265    ///
266    /// # Arguments
267    ///
268    /// * `key` - The key to delete.
269    ///
270    /// # Returns
271    ///
272    /// An `Result` indicating whether the operation was successful or not.
273    pub fn delete(&mut self, key: Key) -> Result<(), SMTError> {
274        let key = key.to_string().parse::<Node>().unwrap();
275
276        let EntryResponse {
277            entry,
278            mut siblings,
279            ..
280        } = self.retrieve_entry(key.clone());
281
282        if entry.get(1).is_none() {
283            return Err(SMTError::KeyDoesNotExist(key.to_string()));
284        }
285
286        let path = key_to_path(&key.to_string());
287
288        let node = (self.hash)(entry.clone());
289        self.nodes.remove(&node);
290
291        self.root = self.zero_node.clone();
292
293        // If there are siblings, the old nodes are deleted and will be re-created below with new hashes.
294        if !siblings.is_empty() {
295            self.delete_old_nodes(node.clone(), &path, &siblings);
296
297            // If the last sibling is not a leaf node, it adds all the nodes of the path starting from
298            // a zero node, otherwise it removes the last non-zero sibling from the `siblings` vector
299            // and it starts from it by skipping the last zero nodes.
300            if !self.is_leaf(&siblings.last().cloned().unwrap()) {
301                self.root = self
302                    .add_new_nodes(self.zero_node.clone(), &path, &siblings, None)
303                    .unwrap();
304            } else {
305                let first_sibling = siblings.pop().unwrap();
306                let i = get_index_of_last_non_zero_element(
307                    siblings
308                        .iter()
309                        .map(|s| s.to_string())
310                        .collect::<Vec<String>>()
311                        .iter()
312                        .map(|s| s.as_str())
313                        .collect::<Vec<&str>>(),
314                );
315
316                self.root = self.add_new_nodes(first_sibling, &path, &siblings, Some(i))?;
317            }
318        }
319
320        Ok(())
321    }
322
323    /// Creates a proof to prove the membership or the non-membership of a tree entry.
324    ///
325    /// # Arguments
326    ///
327    /// * `key` - The key to create the proof for.
328    ///
329    /// # Returns
330    ///
331    /// A `MerkleProof` containing the proof information.
332    pub fn create_proof(&self, key: Key) -> MerkleProof {
333        let key = key.to_string().parse::<Node>().unwrap();
334
335        let EntryResponse {
336            entry,
337            matching_entry,
338            siblings,
339        } = self.retrieve_entry(key);
340
341        // If the key exists, the function returns a proof with the entry itself, otherwise it returns
342        // a non-membership proof with the matching entry.
343        MerkleProof {
344            entry_response: EntryResponse {
345                entry: entry.clone(),
346                matching_entry,
347                siblings,
348            },
349            root: self.root.clone(),
350            membership: entry.get(1).is_some(),
351        }
352    }
353
354    /// Verifies a membership or a non-membership proof for a given key in the SMT.
355    ///
356    /// # Arguments
357    ///
358    /// * `merkle_proof` - The Merkle proof to verify.
359    ///
360    /// # Returns
361    ///
362    /// A boolean indicating whether the proof is valid or not.
363    pub fn verify_proof(&self, merkle_proof: MerkleProof) -> bool {
364        // If there is no matching entry, it simply obtains the root hash by using the siblings and the
365        // path of the key.
366        if merkle_proof.entry_response.matching_entry.is_none() {
367            let path = key_to_path(&merkle_proof.entry_response.entry[0].to_string());
368            // If there is not an entry value, the proof is a non-membership proof. In this case, since there
369            // is not a matching entry, the node is set to a zero node. If there is an entry value, the proof
370            // is a membership proof and the node is set to the hash of the entry.
371            let node = if merkle_proof.entry_response.entry.get(1).is_some() {
372                (self.hash)(merkle_proof.entry_response.entry)
373            } else {
374                self.zero_node.clone()
375            };
376            let root = self.calculate_root(node, &path, &merkle_proof.entry_response.siblings);
377
378            // If the obtained root is equal to the proof root, then the proof is valid.
379            return root == merkle_proof.root;
380        }
381
382        // If there is a matching entry, the proof is definitely a non-membership proof. In this case, it checks
383        // if the matching node belongs to the tree, and then it checks if the number of the first matching bits
384        // of the keys is greater than or equal to the number of the siblings.
385        if let Some(matching_entry) = &merkle_proof.entry_response.matching_entry {
386            let matching_path = key_to_path(&matching_entry[0].to_string());
387            let node = (self.hash)(matching_entry.to_vec());
388            let root =
389                self.calculate_root(node, &matching_path, &merkle_proof.entry_response.siblings);
390
391            if merkle_proof.membership == (root == merkle_proof.root) {
392                let path = key_to_path(&merkle_proof.entry_response.entry[0].to_string());
393                // Returns the first common bits of the two keys: the non-member key and the matching key.
394                let first_matching_bits = get_first_common_elements(&path, &matching_path);
395
396                // If the non-member key was a key of a tree entry, the depth of the matching node should be
397                // greater than the number of the fisrt matching bits. Otherwise, the depth of the node can be
398                // defined by the number of its siblings.
399                return merkle_proof.entry_response.siblings.len() <= first_matching_bits.len();
400            }
401        }
402
403        false
404    }
405
406    /// Retrieves the entry associated with the given key from the SMT.
407    ///
408    /// If the key passed as parameter exists in the SMT, the function returns the entry itself, otherwise
409    /// it returns the entry with only the key. When there is another matching entry in the same path, it
410    /// returns the matching entry as well.
411    ///
412    /// In any case, the function returns the siblings of the path.
413    ///
414    /// # Arguments
415    ///
416    /// * `key` - The key to retrieve the entry for.
417    ///
418    /// # Returns
419    ///
420    /// An `EntryResponse` struct containing the entry, the matching entry (if any), and the siblings of the leaf node.
421    fn retrieve_entry(&self, key: Key) -> EntryResponse {
422        let path = key_to_path(&key.to_string());
423        let mut siblings: Siblings = Vec::new();
424        let mut node = self.root.clone();
425
426        let mut i = 0;
427
428        // Starting from the root, it traverses the tree until it reaches a leaf node, a zero node,
429        // or a matching entry.
430        while node != self.zero_node {
431            let child_nodes = self.nodes.get(&node).unwrap_or(&Vec::new()).clone();
432            let direction = path[i];
433
434            // If the third element of the child nodes is not None, it means that the node is an entry of the tree.
435            if child_nodes.get(2).is_some() {
436                if child_nodes[0] == key {
437                    // An entry is found with the same key, and it returns it with the siblings.
438                    return EntryResponse {
439                        entry: child_nodes,
440                        matching_entry: None,
441                        siblings,
442                    };
443                }
444
445                // An entry was found with a different key, but the key of this particular entry matches the first 'i'
446                // bits of the key passed as parameter. It can be useful in several functions.
447                return EntryResponse {
448                    entry: vec![key.clone()],
449                    matching_entry: Some(child_nodes),
450                    siblings,
451                };
452            }
453
454            // When it goes down into the tree and follows the path, in every step a node is chosen between left
455            // and right child nodes, and the opposite node is saved in the `siblings` vector.
456            node = child_nodes[direction].clone();
457            siblings.push(child_nodes[1 - direction].clone());
458
459            i += 1;
460        }
461
462        // The path led to a zero node.
463        EntryResponse {
464            entry: vec![key],
465            matching_entry: None,
466            siblings,
467        }
468    }
469
470    /// Calculates the root of the tree by using the given node, the path, and the siblings.
471    ///
472    /// It calculates with a bottom up approach by starting from the node and going up to the root.
473    ///
474    /// # Arguments
475    ///
476    /// * `node` - The node to start the calculation from.
477    /// * `path` - The path of the key.
478    /// * `siblings` - The siblings of the path.
479    ///
480    /// # Returns
481    ///
482    /// The root of the tree.
483    fn calculate_root(&self, mut node: Node, path: &[usize], siblings: &Siblings) -> Node {
484        for i in (0..siblings.len()).rev() {
485            let child_nodes: ChildNodes = if path[i] != 0 {
486                vec![siblings[i].clone(), node.clone()]
487            } else {
488                vec![node.clone(), siblings[i].clone()]
489            };
490
491            node = (self.hash)(child_nodes);
492        }
493
494        node
495    }
496
497    /// Adds new nodes to the tree with the new hashes.
498    ///
499    /// It starts with a bottom up approach until it reaches the root of the tree.
500    ///
501    /// # Arguments
502    ///
503    /// * `node` - The node to start the calculation from.
504    /// * `path` - The path of the key.
505    /// * `siblings` - The siblings of the path.
506    /// * `i` - The index of the sibling to start from.
507    ///
508    /// # Returns
509    ///
510    /// The new root of the tree.
511    fn add_new_nodes(
512        &mut self,
513        mut node: Node,
514        path: &[usize],
515        siblings: &Siblings,
516        i: Option<isize>,
517    ) -> Result<Node, SMTError> {
518        let mut starting_index = if let Some(i) = i {
519            i
520        } else {
521            siblings.len() as isize - 1
522        };
523
524        while starting_index > 0 {
525            if siblings.get(starting_index as usize).is_none() {
526                return Err(SMTError::InvalidSiblingIndex);
527            }
528
529            let child_nodes: ChildNodes = if path.get(starting_index as usize).is_some() {
530                vec![siblings[starting_index as usize].clone(), node.clone()]
531            } else {
532                vec![node.clone(), siblings[starting_index as usize].clone()]
533            };
534
535            node = (self.hash)(child_nodes.clone());
536
537            self.nodes.insert(node.clone(), child_nodes);
538
539            starting_index -= 1;
540        }
541
542        Ok(node)
543    }
544
545    /// Deletes the old nodes of the tree.
546    ///
547    /// It starts with a bottom up approach until it reaches the root of the tree.
548    ///
549    /// # Arguments
550    ///
551    /// * `node` - The node to start the calculation from.
552    /// * `path` - The path of the key.
553    /// * `siblings` - The siblings of the path.
554    fn delete_old_nodes(&mut self, mut node: Node, path: &[usize], siblings: &Siblings) {
555        for i in (0..siblings.len()).rev() {
556            let child_nodes: ChildNodes = if path.get(i).is_some() {
557                vec![siblings[i].clone(), node.clone()]
558            } else {
559                vec![node.clone(), siblings[i].clone()]
560            };
561
562            node = (self.hash)(child_nodes);
563
564            self.nodes.remove(&node);
565        }
566    }
567
568    /// Checks if the given node is a leaf node or not.
569    ///
570    /// # Arguments
571    ///
572    /// * `node` - The node to check.
573    ///
574    /// # Returns
575    ///
576    /// A boolean indicating whether the node is a leaf node or not.
577    fn is_leaf(&self, node: &Node) -> bool {
578        if let Some(child_nodes) = self.nodes.get(node) {
579            child_nodes.get(2).is_some()
580        } else {
581            false
582        }
583    }
584}
585
586#[cfg(test)]
587mod tests {
588    use super::*;
589
590    fn hash_function(nodes: Vec<Node>) -> Node {
591        let strings: Vec<String> = nodes.iter().map(|node| node.to_string()).collect();
592        Node::Str(strings.join(","))
593    }
594
595    #[test]
596    fn test_new() {
597        let smt = SMT::new(hash_function, false);
598        assert!(!smt.big_numbers);
599        assert_eq!(smt.zero_node, Node::Str("0".to_string()));
600        assert_eq!(smt.entry_mark, Node::Str("1".to_string()));
601        assert_eq!(smt.nodes, HashMap::new());
602        assert_eq!(smt.root, Node::Str("0".to_string()));
603
604        let smt = SMT::new(hash_function, true);
605        assert!(smt.big_numbers);
606        assert_eq!(smt.zero_node, Node::BigInt(BigInt::from(0)));
607        assert_eq!(smt.entry_mark, Node::BigInt(BigInt::from(1)));
608        assert_eq!(smt.nodes, HashMap::new());
609        assert_eq!(smt.root, Node::BigInt(BigInt::from(0)));
610    }
611
612    #[test]
613    fn test_get() {
614        let mut smt = SMT::new(hash_function, false);
615        let key = Key::Str("aaa".to_string());
616        let value = Value::Str("bbb".to_string());
617        let _ = smt.add(key.clone(), value.clone());
618        let result = smt.get(key.clone());
619        assert_eq!(result, Some(value));
620
621        let key2 = Key::Str("ccc".to_string());
622        let result2 = smt.get(key2.clone());
623        assert_eq!(result2, None);
624
625        let mut smt = SMT::new(hash_function, true);
626        let key = Key::BigInt(BigInt::from(123));
627        let value = Value::BigInt(BigInt::from(456));
628        let _ = smt.add(key.clone(), value.clone());
629        let result = smt.get(key.clone());
630        assert_eq!(result, Some(value));
631    }
632    #[test]
633    fn test_add() {
634        let mut smt = SMT::new(hash_function, false);
635        let key = Key::Str("aaa".to_string());
636        let value = Value::Str("bbb".to_string());
637        let result = smt.add(key.clone(), value.clone());
638        assert!(result.is_ok());
639        assert_eq!(smt.nodes.len(), 1);
640        assert_eq!(
641            smt.nodes.get(&smt.root),
642            Some(&vec![key.clone(), value.clone(), smt.entry_mark.clone()])
643        );
644
645        let mut smt = SMT::new(hash_function, true);
646        let key = Key::BigInt(BigInt::from(123));
647        let value = Value::BigInt(BigInt::from(456));
648        let result = smt.add(key.clone(), value.clone());
649        assert!(result.is_ok());
650        assert_eq!(smt.nodes.len(), 1);
651        assert_eq!(
652            smt.nodes.get(&smt.root),
653            Some(&vec![key.clone(), value.clone(), smt.entry_mark.clone()])
654        );
655    }
656
657    #[test]
658    fn test_update() {
659        let mut smt = SMT::new(hash_function, false);
660        let key = Key::Str("aaa".to_string());
661        let value = Value::Str("bbb".to_string());
662        let _ = smt.add(key.clone(), value.clone());
663
664        let new_value = Value::Str("ccc".to_string());
665        let result = smt.update(key.clone(), new_value.clone());
666        assert!(result.is_ok());
667        assert_eq!(smt.nodes.len(), 1);
668        assert_eq!(
669            smt.nodes.get(&smt.root),
670            Some(&vec![
671                key.clone(),
672                new_value.clone(),
673                smt.entry_mark.clone()
674            ])
675        );
676
677        let key2 = Key::Str("def".to_string());
678        let result2 = smt.update(key2.clone(), new_value.clone());
679        assert_eq!(result2, Err(SMTError::KeyDoesNotExist(key2.to_string())));
680
681        let mut smt = SMT::new(hash_function, true);
682        let key = Key::BigInt(BigInt::from(123));
683        let value = Value::BigInt(BigInt::from(456));
684        let _ = smt.add(key.clone(), value.clone());
685
686        let new_value = Value::BigInt(BigInt::from(789));
687        let result = smt.update(key.clone(), new_value.clone());
688        assert!(result.is_ok());
689        assert_eq!(smt.nodes.len(), 1);
690        assert_eq!(
691            smt.nodes.get(&smt.root),
692            Some(&vec![
693                key.clone(),
694                new_value.clone(),
695                smt.entry_mark.clone()
696            ])
697        );
698    }
699
700    #[test]
701    fn test_delete() {
702        let mut smt = SMT::new(hash_function, false);
703        let key = Key::Str("abc".to_string());
704        let value = Value::Str("123".to_string());
705        let _ = smt.add(key.clone(), value.clone());
706        let result = smt.delete(key.clone());
707        assert!(result.is_ok());
708        assert_eq!(smt.nodes.len(), 0);
709        assert_eq!(smt.root, smt.zero_node);
710
711        let key2 = Key::Str("def".to_string());
712        let result2 = smt.delete(key2.clone());
713        assert_eq!(result2, Err(SMTError::KeyDoesNotExist(key2.to_string())));
714
715        let mut smt = SMT::new(hash_function, true);
716        let key = Key::BigInt(BigInt::from(123));
717        let value = Value::BigInt(BigInt::from(456));
718        let _ = smt.add(key.clone(), value.clone());
719        let result = smt.delete(key.clone());
720        assert!(result.is_ok());
721        assert_eq!(smt.nodes.len(), 0);
722        assert_eq!(smt.root, smt.zero_node);
723    }
724
725    #[test]
726    fn test_create_proof() {
727        let mut smt = SMT::new(hash_function, false);
728        let key = Key::Str("abc".to_string());
729        let value = Value::Str("123".to_string());
730        let _ = smt.add(key.clone(), value.clone());
731        let proof = smt.create_proof(key.clone());
732        assert_eq!(proof.root, smt.root);
733
734        let mut smt = SMT::new(hash_function, true);
735        let key = Key::BigInt(BigInt::from(123));
736        let value = Value::BigInt(BigInt::from(456));
737        let _ = smt.add(key.clone(), value.clone());
738        let proof = smt.create_proof(key.clone());
739        assert_eq!(proof.root, smt.root);
740    }
741
742    #[test]
743    fn test_verify_proof() {
744        let mut smt = SMT::new(hash_function, false);
745        let key = Key::Str("abc".to_string());
746        let value = Value::Str("123".to_string());
747        let _ = smt.add(key.clone(), value.clone());
748        let proof = smt.create_proof(key.clone());
749        let result = smt.verify_proof(proof);
750        assert!(result);
751
752        let key2 = Key::Str("def".to_string());
753        let false_proof = MerkleProof {
754            entry_response: EntryResponse {
755                entry: vec![key2.clone()],
756                matching_entry: None,
757                siblings: Vec::new(),
758            },
759            root: smt.root.clone(),
760            membership: false,
761        };
762        let fun = smt.verify_proof(false_proof);
763        assert!(!fun);
764
765        let mut smt = SMT::new(hash_function, true);
766        let key = Key::BigInt(BigInt::from(123));
767        let value = Value::BigInt(BigInt::from(456));
768        let _ = smt.add(key.clone(), value.clone());
769        let proof = smt.create_proof(key.clone());
770        let result = smt.verify_proof(proof);
771        assert!(result);
772
773        let key2 = Key::BigInt(BigInt::from(789));
774        let false_proof = MerkleProof {
775            entry_response: EntryResponse {
776                entry: vec![key2.clone()],
777                matching_entry: None,
778                siblings: Vec::new(),
779            },
780            root: smt.root.clone(),
781            membership: true,
782        };
783        let fun = smt.verify_proof(false_proof);
784        assert!(!fun);
785    }
786
787    #[test]
788    fn test_retrieve_entry() {
789        let smt = SMT::new(hash_function, false);
790        let key = Key::Str("be12".to_string());
791        let entry_response = smt.retrieve_entry(key.clone());
792        assert_eq!(entry_response.entry, vec![key]);
793        assert_eq!(entry_response.matching_entry, None);
794        assert_eq!(entry_response.siblings, Vec::new());
795
796        let smt = SMT::new(hash_function, true);
797        let key = Key::BigInt(BigInt::from(123));
798        let entry_response = smt.retrieve_entry(key.clone());
799        assert_eq!(entry_response.entry, vec![key]);
800        assert_eq!(entry_response.matching_entry, None);
801        assert_eq!(entry_response.siblings, Vec::new());
802    }
803
804    #[test]
805    fn test_calculate_root() {
806        let smt = SMT::new(hash_function, false);
807        let node = Node::Str("node".to_string());
808        let path = &[0, 1, 0];
809        let siblings = vec![
810            Node::Str("sibling1".to_string()),
811            Node::Str("sibling2".to_string()),
812            Node::Str("sibling3".to_string()),
813        ];
814        let root = smt.calculate_root(node.clone(), path, &siblings);
815        assert_eq!(
816            root,
817            Node::Str("sibling2,node,sibling3,sibling1".to_string())
818        );
819
820        let smt = SMT::new(hash_function, true);
821        let node = Node::BigInt(BigInt::from(123));
822        let path = &[1, 0];
823        let siblings = vec![
824            Node::BigInt(BigInt::from(456)),
825            Node::BigInt(BigInt::from(789)),
826        ];
827        let root = smt.calculate_root(node.clone(), path, &siblings);
828        assert_eq!(root, Node::Str("456,123,789".to_string()));
829    }
830
831    #[test]
832    fn test_add_new_nodes() {
833        let mut smt = SMT::new(hash_function, false);
834        let node = Node::Str("node".to_string());
835        let path = &[0, 1, 0];
836        let siblings = vec![
837            Node::Str("sibling1".to_string()),
838            Node::Str("sibling2".to_string()),
839            Node::Str("sibling3".to_string()),
840        ];
841        let new_node = smt
842            .add_new_nodes(node.clone(), path, &siblings, None)
843            .unwrap();
844        assert_eq!(new_node, Node::Str("sibling2,sibling3,node".to_string()));
845
846        let starting_index = smt
847            .add_new_nodes(node.clone(), path, &siblings, Some(1))
848            .unwrap();
849        assert_eq!(starting_index, Node::Str("sibling2,node".to_string()));
850
851        let mut smt = SMT::new(hash_function, true);
852        let node = Node::BigInt(BigInt::from(111));
853        let path = &[1, 0, 0];
854        let siblings = vec![
855            Node::BigInt(BigInt::from(222)),
856            Node::BigInt(BigInt::from(333)),
857            Node::BigInt(BigInt::from(444)),
858        ];
859        let new_node = smt
860            .add_new_nodes(node.clone(), path, &siblings, None)
861            .unwrap();
862        assert_eq!(new_node, Node::Str("333,444,111".to_string()));
863
864        let starting_index = smt
865            .add_new_nodes(node.clone(), path, &siblings, Some(1))
866            .unwrap();
867        assert_eq!(starting_index, Node::Str("333,111".to_string()));
868    }
869
870    #[test]
871    fn test_delete_old_nodes() {
872        let mut smt = SMT::new(hash_function, false);
873        let node = Node::Str("abc".to_string());
874        let path = &[0, 1, 0];
875        let siblings = vec![
876            Node::Str("sibling1".to_string()),
877            Node::Str("sibling2".to_string()),
878            Node::Str("sibling3".to_string()),
879        ];
880        let new_node = smt
881            .add_new_nodes(node.clone(), path, &siblings, None)
882            .unwrap();
883        assert_eq!(new_node, Node::Str("sibling2,sibling3,abc".to_string()));
884        smt.delete_old_nodes(node.clone(), path, &siblings);
885        assert_eq!(smt.nodes.len(), 0);
886
887        let mut smt = SMT::new(hash_function, true);
888        let node = Node::BigInt(BigInt::from(123));
889        let path = &[1, 0];
890        let siblings = vec![
891            Node::BigInt(BigInt::from(456)),
892            Node::BigInt(BigInt::from(789)),
893        ];
894        let new_node = smt
895            .add_new_nodes(node.clone(), path, &siblings, None)
896            .unwrap();
897        assert_eq!(new_node, Node::Str("789,123".to_string()));
898        smt.delete_old_nodes(node.clone(), path, &siblings);
899        assert_eq!(smt.nodes.len(), 0);
900    }
901
902    #[test]
903    fn test_is_leaf() {
904        let mut smt = SMT::new(hash_function, false);
905        let node = Node::Str("abc".to_string());
906        assert!(!smt.is_leaf(&node));
907
908        smt.nodes.insert(
909            Node::Str("abc".to_string()),
910            vec![
911                Node::Str("123".to_string()),
912                Node::Str("456".to_string()),
913                Node::Str("789".to_string()),
914            ],
915        );
916        assert!(smt.is_leaf(&node));
917
918        let mut smt = SMT::new(hash_function, true);
919        let node = Node::BigInt(BigInt::from(123));
920        assert!(!smt.is_leaf(&node));
921
922        smt.nodes.insert(
923            Node::BigInt(BigInt::from(123)),
924            vec![
925                Node::BigInt(BigInt::from(111)),
926                Node::BigInt(BigInt::from(222)),
927                Node::BigInt(BigInt::from(333)),
928            ],
929        );
930        assert!(smt.is_leaf(&node));
931    }
932}