use std::{collections::HashMap, str::FromStr};
use num_bigint::BigInt;
use crate::utils::{
get_first_common_elements, get_index_of_last_non_zero_element, is_hexadecimal, key_to_path,
};
use std::fmt;
/// Errors returned by the sparse Merkle tree operations.
#[derive(Debug, PartialEq)]
pub enum SMTError {
    /// Returned by `add` when the key is already stored; carries the key.
    KeyAlreadyExist(String),
    /// Returned by `update`/`delete` when the key is not stored; carries the key.
    KeyDoesNotExist(String),
    /// A parameter could not be parsed; carries the offending input and the
    /// description of the expected type.
    InvalidParameterType(String, String),
    /// A sibling index did not address an existing sibling.
    InvalidSiblingIndex,
}

impl fmt::Display for SMTError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::KeyAlreadyExist(key) => write!(f, "Key {} already exists", key),
            Self::KeyDoesNotExist(key) => write!(f, "Key {} does not exist", key),
            Self::InvalidParameterType(param, expected) => {
                write!(f, "Parameter {} must be a {}", param, expected)
            }
            Self::InvalidSiblingIndex => f.write_str("Invalid sibling index"),
        }
    }
}

impl std::error::Error for SMTError {}
/// A tree value: a string (for example a hexadecimal key or a hash output) or a
/// big integer.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Node {
    Str(String),
    BigInt(BigInt),
}

impl fmt::Display for Node {
    /// Writes the string verbatim, or the integer through its own `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Str(text) => f.write_str(text),
            Self::BigInt(number) => write!(f, "{}", number),
        }
    }
}

impl FromStr for Node {
    type Err = SMTError;

    /// Parses `s` as a `BigInt` when possible; otherwise keeps it as a string
    /// if `is_hexadecimal` accepts it.
    ///
    /// # Errors
    ///
    /// `SMTError::InvalidParameterType` for any other input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.parse::<BigInt>() {
            Ok(number) => Ok(Self::BigInt(number)),
            Err(_) if is_hexadecimal(s) => Ok(Self::Str(s.to_owned())),
            Err(_) => Err(SMTError::InvalidParameterType(
                s.to_owned(),
                String::from("BigInt or hexadecimal string"),
            )),
        }
    }
}
/// Key of a tree entry.
pub type Key = Node;
/// Value stored under a key.
pub type Value = Node;
/// Type of the marker stored as the third element of every leaf; having a
/// third element is what distinguishes a leaf from an internal node.
pub type EntryMark = Node;
/// Tuple form of a leaf: `(key, value, entry mark)`.
/// NOTE(review): the tree itself stores leaves as a `Vec<Node>`, not this tuple.
pub type Entry = (Key, Value, EntryMark);
/// Nodes hashed together to form a parent: `[left, right]` for an internal
/// node, `[key, value, entry mark]` for a leaf.
pub type ChildNodes = Vec<Node>;
/// Siblings along a path, ordered from the root downwards.
pub type Siblings = Vec<Node>;
/// Hash applied to the child nodes of every leaf and internal node.
pub type HashFunction = fn(ChildNodes) -> Node;
/// Result of walking the tree from the root along a key's path.
#[derive(Debug, Clone, PartialEq)]
pub struct EntryResponse {
    /// `[key, value, entry mark]` when the key is stored, otherwise just `[key]`.
    pub entry: Vec<Node>,
    /// The leaf of a different key on which the walk ended, if any.
    pub matching_entry: Option<Vec<Node>>,
    /// Siblings of the walked path, from the root downwards.
    pub siblings: Siblings,
}

/// A membership or non-membership proof for a single key.
#[derive(Debug, Clone, PartialEq)]
pub struct MerkleProof {
    /// The walk result the proof was built from.
    entry_response: EntryResponse,
    /// Tree root at the time the proof was created.
    root: Node,
    /// Whether the proof claims that the key is stored.
    membership: bool,
}
/// A sparse Merkle tree mapping keys to values.
///
/// Every stored node is keyed by its hash in `nodes`, which maps it to the
/// child nodes it was hashed from.
#[derive(Debug)]
// `big_numbers` is recorded at construction but, in this file, only the tests read it.
#[allow(dead_code)]
pub struct SMT {
    /// Hash applied to leaves and internal nodes.
    hash: HashFunction,
    /// Whether the zero node and the entry mark are `BigInt`s rather than strings.
    big_numbers: bool,
    /// Placeholder for an empty subtree; also the root of an empty tree.
    zero_node: Node,
    /// Third element of every leaf, distinguishing it from internal nodes.
    entry_mark: Node,
    /// Maps each stored node (its hash) to the nodes it was hashed from.
    nodes: HashMap<Node, Vec<Node>>,
    /// Current root of the tree.
    root: Node,
}
impl SMT {
pub fn new(hash: HashFunction, big_numbers: bool) -> Self {
let zero_node;
let entry_mark;
if big_numbers {
zero_node = Node::BigInt(BigInt::from(0));
entry_mark = Node::BigInt(BigInt::from(1));
} else {
zero_node = Node::Str("0".to_string());
entry_mark = Node::Str("1".to_string());
}
SMT {
hash,
big_numbers,
zero_node: zero_node.clone(),
entry_mark,
nodes: HashMap::new(),
root: zero_node,
}
}
pub fn get(&self, key: Key) -> Option<Value> {
let key = key.to_string().parse::<Node>().unwrap();
let EntryResponse { entry, .. } = self.retrieve_entry(key);
entry.get(1).cloned()
}
pub fn add(&mut self, key: Key, value: Value) -> Result<(), SMTError> {
let key = key.to_string().parse::<Node>().unwrap();
let value = value.to_string().parse::<Node>().unwrap();
let EntryResponse {
entry,
matching_entry,
mut siblings,
} = self.retrieve_entry(key.clone());
if entry.get(1).is_some() {
return Err(SMTError::KeyAlreadyExist(key.to_string()));
}
let path = key_to_path(&key.to_string());
let node = if let Some(ref matching_entry) = matching_entry {
(self.hash)(matching_entry.clone())
} else {
self.zero_node.clone()
};
if !siblings.is_empty() {
self.delete_old_nodes(node.clone(), &path, &siblings)
}
if let Some(matching_entry) = matching_entry {
let matching_path = key_to_path(&matching_entry[0].to_string());
let mut i = siblings.len();
while matching_path[i] == path[i] {
siblings.push(self.zero_node.clone());
i += 1;
}
siblings.push(node.clone());
}
let new_node = (self.hash)(vec![key.clone(), value.clone(), self.entry_mark.clone()]);
self.nodes
.insert(new_node.clone(), vec![key, value, self.entry_mark.clone()]);
self.root = self
.add_new_nodes(new_node, &path, &siblings, None)
.unwrap();
Ok(())
}
pub fn update(&mut self, key: Key, value: Value) -> Result<(), SMTError> {
let key = key.to_string().parse::<Node>().unwrap();
let value = value.to_string().parse::<Node>().unwrap();
let EntryResponse {
entry, siblings, ..
} = self.retrieve_entry(key.clone());
if entry.get(1).is_none() {
return Err(SMTError::KeyDoesNotExist(key.to_string()));
}
let path = key_to_path(&key.to_string());
let old_node = (self.hash)(entry.clone());
self.nodes.remove(&old_node);
self.delete_old_nodes(old_node.clone(), &path, &siblings);
let new_node = (self.hash)(vec![key.clone(), value.clone(), self.entry_mark.clone()]);
self.nodes
.insert(new_node.clone(), vec![key, value, self.entry_mark.clone()]);
self.root = self
.add_new_nodes(new_node, &path, &siblings, None)
.unwrap();
Ok(())
}
pub fn delete(&mut self, key: Key) -> Result<(), SMTError> {
let key = key.to_string().parse::<Node>().unwrap();
let EntryResponse {
entry,
mut siblings,
..
} = self.retrieve_entry(key.clone());
if entry.get(1).is_none() {
return Err(SMTError::KeyDoesNotExist(key.to_string()));
}
let path = key_to_path(&key.to_string());
let node = (self.hash)(entry.clone());
self.nodes.remove(&node);
self.root = self.zero_node.clone();
if !siblings.is_empty() {
self.delete_old_nodes(node.clone(), &path, &siblings);
if !self.is_leaf(&siblings.last().cloned().unwrap()) {
self.root = self
.add_new_nodes(self.zero_node.clone(), &path, &siblings, None)
.unwrap();
} else {
let first_sibling = siblings.pop().unwrap();
let i = get_index_of_last_non_zero_element(
siblings
.iter()
.map(|s| s.to_string())
.collect::<Vec<String>>()
.iter()
.map(|s| s.as_str())
.collect::<Vec<&str>>(),
);
self.root = self.add_new_nodes(first_sibling, &path, &siblings, Some(i))?;
}
}
Ok(())
}
pub fn create_proof(&self, key: Key) -> MerkleProof {
let key = key.to_string().parse::<Node>().unwrap();
let EntryResponse {
entry,
matching_entry,
siblings,
} = self.retrieve_entry(key);
MerkleProof {
entry_response: EntryResponse {
entry: entry.clone(),
matching_entry,
siblings,
},
root: self.root.clone(),
membership: entry.get(1).is_some(),
}
}
pub fn verify_proof(&self, merkle_proof: MerkleProof) -> bool {
if merkle_proof.entry_response.matching_entry.is_none() {
let path = key_to_path(&merkle_proof.entry_response.entry[0].to_string());
let node = if merkle_proof.entry_response.entry.get(1).is_some() {
(self.hash)(merkle_proof.entry_response.entry)
} else {
self.zero_node.clone()
};
let root = self.calculate_root(node, &path, &merkle_proof.entry_response.siblings);
return root == merkle_proof.root;
}
if let Some(matching_entry) = &merkle_proof.entry_response.matching_entry {
let matching_path = key_to_path(&matching_entry[0].to_string());
let node = (self.hash)(matching_entry.to_vec());
let root =
self.calculate_root(node, &matching_path, &merkle_proof.entry_response.siblings);
if merkle_proof.membership == (root == merkle_proof.root) {
let path = key_to_path(&merkle_proof.entry_response.entry[0].to_string());
let first_matching_bits = get_first_common_elements(&path, &matching_path);
return merkle_proof.entry_response.siblings.len() <= first_matching_bits.len();
}
}
false
}
fn retrieve_entry(&self, key: Key) -> EntryResponse {
let path = key_to_path(&key.to_string());
let mut siblings: Siblings = Vec::new();
let mut node = self.root.clone();
let mut i = 0;
while node != self.zero_node {
let child_nodes = self.nodes.get(&node).unwrap_or(&Vec::new()).clone();
let direction = path[i];
if child_nodes.get(2).is_some() {
if child_nodes[0] == key {
return EntryResponse {
entry: child_nodes,
matching_entry: None,
siblings,
};
}
return EntryResponse {
entry: vec![key.clone()],
matching_entry: Some(child_nodes),
siblings,
};
}
node = child_nodes[direction].clone();
siblings.push(child_nodes[1 - direction].clone());
i += 1;
}
EntryResponse {
entry: vec![key],
matching_entry: None,
siblings,
}
}
fn calculate_root(&self, mut node: Node, path: &[usize], siblings: &Siblings) -> Node {
for i in (0..siblings.len()).rev() {
let child_nodes: ChildNodes = if path[i] != 0 {
vec![siblings[i].clone(), node.clone()]
} else {
vec![node.clone(), siblings[i].clone()]
};
node = (self.hash)(child_nodes);
}
node
}
fn add_new_nodes(
&mut self,
mut node: Node,
path: &[usize],
siblings: &Siblings,
i: Option<isize>,
) -> Result<Node, SMTError> {
let mut starting_index = if let Some(i) = i {
i
} else {
siblings.len() as isize - 1
};
while starting_index > 0 {
if siblings.get(starting_index as usize).is_none() {
return Err(SMTError::InvalidSiblingIndex);
}
let child_nodes: ChildNodes = if path.get(starting_index as usize).is_some() {
vec![siblings[starting_index as usize].clone(), node.clone()]
} else {
vec![node.clone(), siblings[starting_index as usize].clone()]
};
node = (self.hash)(child_nodes.clone());
self.nodes.insert(node.clone(), child_nodes);
starting_index -= 1;
}
Ok(node)
}
fn delete_old_nodes(&mut self, mut node: Node, path: &[usize], siblings: &Siblings) {
for i in (0..siblings.len()).rev() {
let child_nodes: ChildNodes = if path.get(i).is_some() {
vec![siblings[i].clone(), node.clone()]
} else {
vec![node.clone(), siblings[i].clone()]
};
node = (self.hash)(child_nodes);
self.nodes.remove(&node);
}
}
fn is_leaf(&self, node: &Node) -> bool {
if let Some(child_nodes) = self.nodes.get(node) {
child_nodes.get(2).is_some()
} else {
false
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn hash_function(nodes: Vec<Node>) -> Node {
let strings: Vec<String> = nodes.iter().map(|node| node.to_string()).collect();
Node::Str(strings.join(","))
}
#[test]
fn test_new() {
let smt = SMT::new(hash_function, false);
assert!(!smt.big_numbers);
assert_eq!(smt.zero_node, Node::Str("0".to_string()));
assert_eq!(smt.entry_mark, Node::Str("1".to_string()));
assert_eq!(smt.nodes, HashMap::new());
assert_eq!(smt.root, Node::Str("0".to_string()));
let smt = SMT::new(hash_function, true);
assert!(smt.big_numbers);
assert_eq!(smt.zero_node, Node::BigInt(BigInt::from(0)));
assert_eq!(smt.entry_mark, Node::BigInt(BigInt::from(1)));
assert_eq!(smt.nodes, HashMap::new());
assert_eq!(smt.root, Node::BigInt(BigInt::from(0)));
}
#[test]
fn test_get() {
let mut smt = SMT::new(hash_function, false);
let key = Key::Str("aaa".to_string());
let value = Value::Str("bbb".to_string());
let _ = smt.add(key.clone(), value.clone());
let result = smt.get(key.clone());
assert_eq!(result, Some(value));
let key2 = Key::Str("ccc".to_string());
let result2 = smt.get(key2.clone());
assert_eq!(result2, None);
let mut smt = SMT::new(hash_function, true);
let key = Key::BigInt(BigInt::from(123));
let value = Value::BigInt(BigInt::from(456));
let _ = smt.add(key.clone(), value.clone());
let result = smt.get(key.clone());
assert_eq!(result, Some(value));
}
#[test]
fn test_add() {
let mut smt = SMT::new(hash_function, false);
let key = Key::Str("aaa".to_string());
let value = Value::Str("bbb".to_string());
let result = smt.add(key.clone(), value.clone());
assert!(result.is_ok());
assert_eq!(smt.nodes.len(), 1);
assert_eq!(
smt.nodes.get(&smt.root),
Some(&vec![key.clone(), value.clone(), smt.entry_mark.clone()])
);
let mut smt = SMT::new(hash_function, true);
let key = Key::BigInt(BigInt::from(123));
let value = Value::BigInt(BigInt::from(456));
let result = smt.add(key.clone(), value.clone());
assert!(result.is_ok());
assert_eq!(smt.nodes.len(), 1);
assert_eq!(
smt.nodes.get(&smt.root),
Some(&vec![key.clone(), value.clone(), smt.entry_mark.clone()])
);
}
#[test]
fn test_update() {
let mut smt = SMT::new(hash_function, false);
let key = Key::Str("aaa".to_string());
let value = Value::Str("bbb".to_string());
let _ = smt.add(key.clone(), value.clone());
let new_value = Value::Str("ccc".to_string());
let result = smt.update(key.clone(), new_value.clone());
assert!(result.is_ok());
assert_eq!(smt.nodes.len(), 1);
assert_eq!(
smt.nodes.get(&smt.root),
Some(&vec![
key.clone(),
new_value.clone(),
smt.entry_mark.clone()
])
);
let key2 = Key::Str("def".to_string());
let result2 = smt.update(key2.clone(), new_value.clone());
assert_eq!(result2, Err(SMTError::KeyDoesNotExist(key2.to_string())));
let mut smt = SMT::new(hash_function, true);
let key = Key::BigInt(BigInt::from(123));
let value = Value::BigInt(BigInt::from(456));
let _ = smt.add(key.clone(), value.clone());
let new_value = Value::BigInt(BigInt::from(789));
let result = smt.update(key.clone(), new_value.clone());
assert!(result.is_ok());
assert_eq!(smt.nodes.len(), 1);
assert_eq!(
smt.nodes.get(&smt.root),
Some(&vec![
key.clone(),
new_value.clone(),
smt.entry_mark.clone()
])
);
}
#[test]
fn test_delete() {
let mut smt = SMT::new(hash_function, false);
let key = Key::Str("abc".to_string());
let value = Value::Str("123".to_string());
let _ = smt.add(key.clone(), value.clone());
let result = smt.delete(key.clone());
assert!(result.is_ok());
assert_eq!(smt.nodes.len(), 0);
assert_eq!(smt.root, smt.zero_node);
let key2 = Key::Str("def".to_string());
let result2 = smt.delete(key2.clone());
assert_eq!(result2, Err(SMTError::KeyDoesNotExist(key2.to_string())));
let mut smt = SMT::new(hash_function, true);
let key = Key::BigInt(BigInt::from(123));
let value = Value::BigInt(BigInt::from(456));
let _ = smt.add(key.clone(), value.clone());
let result = smt.delete(key.clone());
assert!(result.is_ok());
assert_eq!(smt.nodes.len(), 0);
assert_eq!(smt.root, smt.zero_node);
}
#[test]
fn test_create_proof() {
let mut smt = SMT::new(hash_function, false);
let key = Key::Str("abc".to_string());
let value = Value::Str("123".to_string());
let _ = smt.add(key.clone(), value.clone());
let proof = smt.create_proof(key.clone());
assert_eq!(proof.root, smt.root);
let mut smt = SMT::new(hash_function, true);
let key = Key::BigInt(BigInt::from(123));
let value = Value::BigInt(BigInt::from(456));
let _ = smt.add(key.clone(), value.clone());
let proof = smt.create_proof(key.clone());
assert_eq!(proof.root, smt.root);
}
#[test]
fn test_verify_proof() {
let mut smt = SMT::new(hash_function, false);
let key = Key::Str("abc".to_string());
let value = Value::Str("123".to_string());
let _ = smt.add(key.clone(), value.clone());
let proof = smt.create_proof(key.clone());
let result = smt.verify_proof(proof);
assert!(result);
let key2 = Key::Str("def".to_string());
let false_proof = MerkleProof {
entry_response: EntryResponse {
entry: vec![key2.clone()],
matching_entry: None,
siblings: Vec::new(),
},
root: smt.root.clone(),
membership: false,
};
let fun = smt.verify_proof(false_proof);
assert!(!fun);
let mut smt = SMT::new(hash_function, true);
let key = Key::BigInt(BigInt::from(123));
let value = Value::BigInt(BigInt::from(456));
let _ = smt.add(key.clone(), value.clone());
let proof = smt.create_proof(key.clone());
let result = smt.verify_proof(proof);
assert!(result);
let key2 = Key::BigInt(BigInt::from(789));
let false_proof = MerkleProof {
entry_response: EntryResponse {
entry: vec![key2.clone()],
matching_entry: None,
siblings: Vec::new(),
},
root: smt.root.clone(),
membership: true,
};
let fun = smt.verify_proof(false_proof);
assert!(!fun);
}
#[test]
fn test_retrieve_entry() {
let smt = SMT::new(hash_function, false);
let key = Key::Str("be12".to_string());
let entry_response = smt.retrieve_entry(key.clone());
assert_eq!(entry_response.entry, vec![key]);
assert_eq!(entry_response.matching_entry, None);
assert_eq!(entry_response.siblings, Vec::new());
let smt = SMT::new(hash_function, true);
let key = Key::BigInt(BigInt::from(123));
let entry_response = smt.retrieve_entry(key.clone());
assert_eq!(entry_response.entry, vec![key]);
assert_eq!(entry_response.matching_entry, None);
assert_eq!(entry_response.siblings, Vec::new());
}
#[test]
fn test_calculate_root() {
let smt = SMT::new(hash_function, false);
let node = Node::Str("node".to_string());
let path = &[0, 1, 0];
let siblings = vec![
Node::Str("sibling1".to_string()),
Node::Str("sibling2".to_string()),
Node::Str("sibling3".to_string()),
];
let root = smt.calculate_root(node.clone(), path, &siblings);
assert_eq!(
root,
Node::Str("sibling2,node,sibling3,sibling1".to_string())
);
let smt = SMT::new(hash_function, true);
let node = Node::BigInt(BigInt::from(123));
let path = &[1, 0];
let siblings = vec![
Node::BigInt(BigInt::from(456)),
Node::BigInt(BigInt::from(789)),
];
let root = smt.calculate_root(node.clone(), path, &siblings);
assert_eq!(root, Node::Str("456,123,789".to_string()));
}
#[test]
fn test_add_new_nodes() {
let mut smt = SMT::new(hash_function, false);
let node = Node::Str("node".to_string());
let path = &[0, 1, 0];
let siblings = vec![
Node::Str("sibling1".to_string()),
Node::Str("sibling2".to_string()),
Node::Str("sibling3".to_string()),
];
let new_node = smt
.add_new_nodes(node.clone(), path, &siblings, None)
.unwrap();
assert_eq!(new_node, Node::Str("sibling2,sibling3,node".to_string()));
let starting_index = smt
.add_new_nodes(node.clone(), path, &siblings, Some(1))
.unwrap();
assert_eq!(starting_index, Node::Str("sibling2,node".to_string()));
let mut smt = SMT::new(hash_function, true);
let node = Node::BigInt(BigInt::from(111));
let path = &[1, 0, 0];
let siblings = vec![
Node::BigInt(BigInt::from(222)),
Node::BigInt(BigInt::from(333)),
Node::BigInt(BigInt::from(444)),
];
let new_node = smt
.add_new_nodes(node.clone(), path, &siblings, None)
.unwrap();
assert_eq!(new_node, Node::Str("333,444,111".to_string()));
let starting_index = smt
.add_new_nodes(node.clone(), path, &siblings, Some(1))
.unwrap();
assert_eq!(starting_index, Node::Str("333,111".to_string()));
}
#[test]
fn test_delete_old_nodes() {
let mut smt = SMT::new(hash_function, false);
let node = Node::Str("abc".to_string());
let path = &[0, 1, 0];
let siblings = vec![
Node::Str("sibling1".to_string()),
Node::Str("sibling2".to_string()),
Node::Str("sibling3".to_string()),
];
let new_node = smt
.add_new_nodes(node.clone(), path, &siblings, None)
.unwrap();
assert_eq!(new_node, Node::Str("sibling2,sibling3,abc".to_string()));
smt.delete_old_nodes(node.clone(), path, &siblings);
assert_eq!(smt.nodes.len(), 0);
let mut smt = SMT::new(hash_function, true);
let node = Node::BigInt(BigInt::from(123));
let path = &[1, 0];
let siblings = vec![
Node::BigInt(BigInt::from(456)),
Node::BigInt(BigInt::from(789)),
];
let new_node = smt
.add_new_nodes(node.clone(), path, &siblings, None)
.unwrap();
assert_eq!(new_node, Node::Str("789,123".to_string()));
smt.delete_old_nodes(node.clone(), path, &siblings);
assert_eq!(smt.nodes.len(), 0);
}
#[test]
fn test_is_leaf() {
let mut smt = SMT::new(hash_function, false);
let node = Node::Str("abc".to_string());
assert!(!smt.is_leaf(&node));
smt.nodes.insert(
Node::Str("abc".to_string()),
vec![
Node::Str("123".to_string()),
Node::Str("456".to_string()),
Node::Str("789".to_string()),
],
);
assert!(smt.is_leaf(&node));
let mut smt = SMT::new(hash_function, true);
let node = Node::BigInt(BigInt::from(123));
assert!(!smt.is_leaf(&node));
smt.nodes.insert(
Node::BigInt(BigInt::from(123)),
vec![
Node::BigInt(BigInt::from(111)),
Node::BigInt(BigInt::from(222)),
Node::BigInt(BigInt::from(333)),
],
);
assert!(smt.is_leaf(&node));
}
}