use crate::*;
use std::cmp::{max, min};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
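
// Reserved database keys for tree metadata, placed at the top of the u64 range
// to stay clear of the Cantor-paired node keys.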
const DEPTH_KEY: DBKey = (u64::MAX - 1).to_be_bytes();
const NEXT_INDEX_KEY: DBKey = u64::MAX.to_be_bytes();
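/// Depth assumed by `load` when no depth is stored in the database.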
const DEFAULT_TREE_DEPTH: usize = 20;
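
/// Coordinates of a node in the tree: `Key(depth, index)`, where depth 0 is the
/// root and depth equal to the tree depth is the leaf level.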
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Key(usize, usize);
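
// A node key is mapped to a database key via the Cantor pairing function, which
// encodes the (depth, index) pair as a single u64.
// For example, Key(2, 1) maps to (2 + 1) * (2 + 1 + 1) / 2 + 1 = 7.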
impl From<Key> for DBKey {
    fn from(key: Key) -> Self {
        let cantor_pairing = ((key.0 + key.1) * (key.0 + key.1 + 1) / 2 + key.1) as u64;
        cantor_pairing.to_be_bytes()
    }
}
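
/// Merkle tree backed by a persistent key-value `Database`, hashing with `H`.
/// Nodes are stored under Cantor-paired `Key`s; nodes that were never written
/// fall back to per-level default hashes, so storage stays sparse.
///
/// # Example
///
/// A minimal usage sketch. `MemoryDb` and `PoseidonHash` stand in for concrete
/// `Database` and `Hasher` implementations and are not provided by this module.
///
/// ```ignore
/// let mut tree = MerkleTree::<MemoryDb, PoseidonHash>::default(20)?;
///
/// let leaf = PoseidonHash::default_leaf();
/// tree.update_next(leaf)?;                // insert at the next free index
///
/// let proof = tree.proof(0)?;             // inclusion proof for leaf 0
/// assert!(tree.verify(&leaf, &proof));
/// ```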
pub struct MerkleTree<D, H>
where
    D: Database,
    H: Hasher,
{
    pub db: D,
    depth: usize,        // number of levels below the root
    next_index: usize,   // index of the next free leaf
    cache: Vec<H::Fr>,   // default (all-empty) node value for each level
    root: H::Fr,         // current root hash
}
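
/// Merkle inclusion proof: `(sibling, path_bit)` pairs ordered from the leaf's
/// level up to the root. A path bit of 0 means the proven node is the left
/// child at that level.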
#[derive(Clone, PartialEq, Eq)]
pub struct MerkleProof<H: Hasher>(pub Vec<(H::Fr, u8)>);
impl<D, H> MerkleTree<D, H>
where
    D: Database,
    H: Hasher,
{
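    /// Creates a tree of the given depth using the database's default config.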
    pub fn default(depth: usize) -> PmtreeResult<Self> {
        Self::new(depth, D::Config::default())
    }
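
    /// Creates a new, empty tree of `depth` levels over a freshly created database,
    /// persisting the depth, the next free index, and the default nodes of the
    /// leftmost branch.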
    pub fn new(depth: usize, db_config: D::Config) -> PmtreeResult<Self> {
        let mut db = D::new(db_config)?;

        // Persist tree metadata.
        let depth_val = depth.to_be_bytes().to_vec();
        db.put(DEPTH_KEY, depth_val)?;

        let next_index = 0usize;
        let next_index_val = next_index.to_be_bytes().to_vec();
        db.put(NEXT_INDEX_KEY, next_index_val)?;

        // Precompute the default (all-empty) node value for every level and store
        // the leftmost node of each level, so the empty tree is fully queryable.
        let mut cache = vec![H::default_leaf(); depth + 1];

        cache[depth] = H::default_leaf();
        db.put(Key(depth, 0).into(), H::serialize(cache[depth]))?;

        for i in (0..depth).rev() {
            cache[i] = H::hash(&[cache[i + 1], cache[i + 1]]);
            db.put(Key(i, 0).into(), H::serialize(cache[i]))?;
        }

        let root = cache[0];

        Ok(Self {
            db,
            depth,
            next_index,
            cache,
            root,
        })
    }
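
    /// Reopens a previously created tree from an existing database, restoring the
    /// root, depth, and next free index, and rebuilding the per-level default hashes.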
    pub fn load(db_config: D::Config) -> PmtreeResult<Self> {
        let db = D::load(db_config)?;

        // Restore metadata, falling back to defaults if it was never stored.
        let root = match db.get(Key(0, 0).into())? {
            Some(root) => H::deserialize(root),
            None => H::default_leaf(),
        };

        let depth = match db.get(DEPTH_KEY)? {
            Some(depth) => usize::from_be_bytes(depth.try_into().unwrap()),
            None => DEFAULT_TREE_DEPTH,
        };

        let next_index = match db.get(NEXT_INDEX_KEY)? {
            Some(next_index) => usize::from_be_bytes(next_index.try_into().unwrap()),
            None => 0,
        };

        // Rebuild the per-level default hashes; concrete nodes stay in the database.
        let mut cache = vec![H::default_leaf(); depth + 1];
        cache[depth] = H::default_leaf();
        for i in (0..depth).rev() {
            cache[i] = H::hash(&[cache[i + 1], cache[i + 1]]);
        }

        Ok(Self {
            db,
            depth,
            next_index,
            cache,
            root,
        })
    }
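
    /// Closes the underlying database.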
    pub fn close(&mut self) -> PmtreeResult<()> {
        self.db.close()
    }
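
    /// Sets the leaf at `key` and rehashes the path from that leaf to the root.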
    pub fn set(&mut self, key: usize, leaf: H::Fr) -> PmtreeResult<()> {
        if key >= self.capacity() {
            return Err(PmtreeErrorKind::TreeError(TreeErrorKind::IndexOutOfBounds));
        }

        // Store the leaf, then rehash its path up to the root.
        self.db
            .put(Key(self.depth, key).into(), H::serialize(leaf))?;
        self.recalculate_from(key)?;

        // Advance and persist the next free index if this write extended the tree.
        self.next_index = max(self.next_index, key + 1);
        let next_index_val = self.next_index.to_be_bytes().to_vec();
        self.db.put(NEXT_INDEX_KEY, next_index_val)?;

        Ok(())
    }
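
    /// Rehashes every node on the path from the leaf at `key` up to the root.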
    fn recalculate_from(&mut self, key: usize) -> PmtreeResult<()> {
        let mut depth = self.depth;
        let mut i = key;

        loop {
            // Rehash the pair containing `i`, then move to the parent one level up.
            let value = self.hash_couple(depth, i)?;
            i >>= 1;
            depth -= 1;
            self.db.put(Key(depth, i).into(), H::serialize(value))?;

            if depth == 0 {
                self.root = value;
                break;
            }
        }

        Ok(())
    }
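
    /// Hashes the pair of sibling nodes at `depth` that contains index `key`.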
    fn hash_couple(&self, depth: usize, key: usize) -> PmtreeResult<H::Fr> {
        // Clear the lowest bit to get the left element of the pair.
        let b = key & !1;
        Ok(H::hash(&[
            self.get_elem(Key(depth, b))?,
            self.get_elem(Key(depth, b + 1))?,
        ]))
    }
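
    /// Returns the node stored under `key`, or the default hash for its level.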
    pub fn get_elem(&self, key: Key) -> PmtreeResult<H::Fr> {
        // Nodes that were never written fall back to the default hash for their level.
        let res = self
            .db
            .get(key.into())?
            .map_or(self.cache[key.0], |value| H::deserialize(value));

        Ok(res)
    }
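
    /// Resets the leaf at `key` to the default leaf value.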
    pub fn delete(&mut self, key: usize) -> PmtreeResult<()> {
        if key >= self.next_index {
            return Err(PmtreeErrorKind::TreeError(TreeErrorKind::InvalidKey));
        }

        self.set(key, H::default_leaf())?;

        Ok(())
    }
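
    /// Inserts `leaf` at the next free index.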
    pub fn update_next(&mut self, leaf: H::Fr) -> PmtreeResult<()> {
        self.set(self.next_index, leaf)?;

        Ok(())
    }
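
    /// Sets a contiguous range of leaves beginning at `start`.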
    pub fn set_range<I: IntoIterator<Item = H::Fr>>(
        &mut self,
        start: usize,
        leaves: I,
    ) -> PmtreeResult<()> {
        self.batch_insert(
            Some(start),
            leaves.into_iter().collect::<Vec<_>>().as_slice(),
        )
    }
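
    /// Inserts `leaves` as a contiguous batch starting at `start` (or at the next
    /// free index if `start` is `None`), recomputing the affected nodes in parallel
    /// and writing them back in a single batch.
    ///
    /// A sketch, with `tree` and the leaf values assumed to exist as in the
    /// struct-level example:
    ///
    /// ```ignore
    /// // Append three leaves starting at the next free index.
    /// tree.batch_insert(None, &[leaf_a, leaf_b, leaf_c])?;
    /// ```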
    pub fn batch_insert(&mut self, start: Option<usize>, leaves: &[H::Fr]) -> PmtreeResult<()> {
        let start = start.unwrap_or(self.next_index);
        let end = start + leaves.len();

        if end > self.capacity() {
            return Err(PmtreeErrorKind::TreeError(TreeErrorKind::MerkleTreeIsFull));
        }

        // Collect the affected nodes (and their siblings) into an in-memory subtree.
        let mut subtree = HashMap::<Key, H::Fr>::new();
        let root_key = Key(0, 0);

        subtree.insert(root_key, self.root);
        self.fill_nodes(root_key, start, end, &mut subtree, leaves, start)?;

        // Recompute the affected hashes in parallel on a dedicated rayon pool.
        let subtree = Arc::new(RwLock::new(subtree));
        let root_val = rayon::ThreadPoolBuilder::new()
            .num_threads(rayon::current_num_threads())
            .build()
            .unwrap()
            .install(|| Self::batch_recalculate(root_key, Arc::clone(&subtree), self.depth));

        let subtree = RwLock::into_inner(Arc::try_unwrap(subtree).unwrap()).unwrap();

        // Flush the whole updated subtree in one batch write.
        self.db.put_batch(
            subtree
                .into_iter()
                .map(|(key, value)| (key.into(), H::serialize(value)))
                .collect(),
        )?;

        if end > self.next_index {
            self.next_index = end;
            self.db
                .put(NEXT_INDEX_KEY, self.next_index.to_be_bytes().to_vec())?;
        }

        self.root = root_val;

        Ok(())
    }
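
    /// Recursively loads the nodes affected by a batch insert (plus their siblings)
    /// into `subtree`, placing the new leaves at the leaf level. `start` and `end`
    /// are leaf bounds relative to the current node; `from` is the absolute start index.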
    fn fill_nodes(
        &self,
        key: Key,
        start: usize,
        end: usize,
        subtree: &mut HashMap<Key, H::Fr>,
        leaves: &[H::Fr],
        from: usize,
    ) -> PmtreeResult<()> {
        // At the leaf level, write the new leaf; indices before `from` keep their value.
        if key.0 == self.depth {
            if key.1 >= from {
                subtree.insert(key, leaves[key.1 - from]);
            }
            return Ok(());
        }

        let left = Key(key.0 + 1, key.1 * 2);
        let right = Key(key.0 + 1, key.1 * 2 + 1);

        // Seed both children with their current values so untouched siblings are available.
        let left_val = self.get_elem(left)?;
        let right_val = self.get_elem(right)?;

        subtree.insert(left, left_val);
        subtree.insert(right, right_val);

        // `half` is the number of leaves under each child; recurse only into
        // children whose leaf range overlaps [start, end).
        let half = 1 << (self.depth - key.0 - 1);

        if start < half {
            self.fill_nodes(left, start, min(end, half), subtree, leaves, from)?;
        }

        if end > half {
            self.fill_nodes(right, 0, end - half, subtree, leaves, from)?;
        }

        Ok(())
    }
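
    /// Recomputes the hashes of an in-memory subtree bottom-up, using rayon to
    /// process sibling subtrees in parallel, and returns the new root.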
    fn batch_recalculate(
        key: Key,
        subtree: Arc<RwLock<HashMap<Key, H::Fr>>>,
        depth: usize,
    ) -> H::Fr {
        let left_child = Key(key.0 + 1, key.1 * 2);
        let right_child = Key(key.0 + 1, key.1 * 2 + 1);

        // Leaves, and nodes whose children were not loaded, keep their current value.
        if key.0 == depth || !subtree.read().unwrap().contains_key(&left_child) {
            return *subtree.read().unwrap().get(&key).unwrap();
        }

        // Recompute both children in parallel, then hash them into this node.
        let (left, right) = rayon::join(
            || Self::batch_recalculate(left_child, Arc::clone(&subtree), depth),
            || Self::batch_recalculate(right_child, Arc::clone(&subtree), depth),
        );

        let result = H::hash(&[left, right]);

        subtree.write().unwrap().insert(key, result);

        result
    }
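
    /// Builds a Merkle inclusion proof for the leaf at `index`.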
    pub fn proof(&self, index: usize) -> PmtreeResult<MerkleProof<H>> {
        if index >= self.capacity() {
            return Err(PmtreeErrorKind::TreeError(TreeErrorKind::IndexOutOfBounds));
        }

        let mut witness = Vec::with_capacity(self.depth);

        let mut i = index;
        let mut depth = self.depth;
        while depth != 0 {
            // Record the sibling at this level together with the path bit
            // (0 if the proven node is the left child, 1 if it is the right).
            i ^= 1;
            witness.push((
                self.get_elem(Key(depth, i))?,
                (1 - (i & 1)).try_into().unwrap(),
            ));
            i >>= 1;
            depth -= 1;
        }

        Ok(MerkleProof(witness))
    }
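
    /// Checks a proof by recomputing the root from `leaf` and comparing it with
    /// the current root.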
    pub fn verify(&self, leaf: &H::Fr, witness: &MerkleProof<H>) -> bool {
        let expected_root = witness.compute_root_from(leaf);

        self.root() == expected_root
    }
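
    /// Returns the leaf value at `key` (the default leaf if it was never set).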
    pub fn get(&self, key: usize) -> PmtreeResult<H::Fr> {
        if key >= self.capacity() {
            return Err(PmtreeErrorKind::TreeError(TreeErrorKind::IndexOutOfBounds));
        }

        self.get_elem(Key(self.depth, key))
    }

    /// Returns the current root of the tree.
    pub fn root(&self) -> H::Fr {
        self.root
    }

    /// Returns the number of leaf slots used so far (the next free index).
    pub fn leaves_set(&self) -> usize {
        self.next_index
    }

    /// Returns the total number of leaves the tree can hold (2^depth).
    pub fn capacity(&self) -> usize {
        1 << self.depth
    }

    /// Returns the depth of the tree.
    pub fn depth(&self) -> usize {
        self.depth
    }
}

impl<H: Hasher> MerkleProof<H> {
    /// Recomputes the root implied by this proof for the given leaf.
    pub fn compute_root_from(&self, leaf: &H::Fr) -> H::Fr {
        let mut acc = *leaf;
        for w in self.0.iter() {
            // Path bit 0: the accumulated node is the left child; 1: the right child.
            if w.1 == 0 {
                acc = H::hash(&[acc, w.0]);
            } else {
                acc = H::hash(&[w.0, acc]);
            }
        }

        acc
    }

    /// Reconstructs the leaf index from the path bits (root-side bit is most significant).
    pub fn leaf_index(&self) -> usize {
        self.get_path_index()
            .into_iter()
            .rev()
            .fold(0, |acc, digit| (acc << 1) + usize::from(digit))
    }

    /// Returns the path bits, ordered from the leaf's level up to the root.
    pub fn get_path_index(&self) -> Vec<u8> {
        self.0.iter().map(|x| x.1).collect()
    }

    /// Returns the sibling hashes, ordered from the leaf's level up to the root.
    pub fn get_path_elements(&self) -> Vec<H::Fr> {
        self.0.iter().map(|x| x.0).collect()
    }

    /// Returns the number of levels in the proof (the tree depth).
    pub fn length(&self) -> usize {
        self.0.len()
    }
}