// zk_kit_pmt/tree.rs

1use crate::*;
2
3use std::cmp::{max, min};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
7// db[DEPTH_KEY] = depth
8const DEPTH_KEY: DBKey = (u64::MAX - 1).to_be_bytes();
9
10// db[NEXT_INDEX_KEY] = next_index;
11const NEXT_INDEX_KEY: DBKey = u64::MAX.to_be_bytes();
12
13// Default tree depth
14const DEFAULT_TREE_DEPTH: usize = 20;
15
16// Denotes keys (depth, index) in Merkle Tree. Can be converted to DBKey
17// TODO! Think about using hashing for that
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
19pub struct Key(usize, usize);
20impl From<Key> for DBKey {
21    fn from(key: Key) -> Self {
22        let cantor_pairing = ((key.0 + key.1) * (key.0 + key.1 + 1) / 2 + key.1) as u64;
23        cantor_pairing.to_be_bytes()
24    }
25}
26
27/// The Merkle Tree structure
28pub struct MerkleTree<D, H>
29where
30    D: Database,
31    H: Hasher,
32{
33    pub db: D,
34    depth: usize,
35    next_index: usize,
36    cache: Vec<H::Fr>,
37    root: H::Fr,
38}
39
40/// The Merkle proof structure
41#[derive(Clone, PartialEq, Eq)]
42pub struct MerkleProof<H: Hasher>(pub Vec<(H::Fr, u8)>);
43
44impl<D, H> MerkleTree<D, H>
45where
46    D: Database,
47    H: Hasher,
48{
49    /// Creates tree with specified depth and default "pmtree_db" dbpath.
50    pub fn default(depth: usize) -> PmtreeResult<Self> {
51        Self::new(depth, D::Config::default())
52    }
53
54    /// Creates new `MerkleTree` and store it to the specified path/db
55    pub fn new(depth: usize, db_config: D::Config) -> PmtreeResult<Self> {
56        // Create new db instance
57        let mut db = D::new(db_config)?;
58
59        // Insert depth val into db
60        let depth_val = depth.to_be_bytes().to_vec();
61        db.put(DEPTH_KEY, depth_val)?;
62
63        // Insert next_index val into db
64        let next_index = 0usize;
65        let next_index_val = next_index.to_be_bytes().to_vec();
66        db.put(NEXT_INDEX_KEY, next_index_val)?;
67
68        // Cache nodes
69        let mut cache = vec![H::default_leaf(); depth + 1];
70
71        // Initialize one branch of the `Merkle Tree` from bottom to top
72        cache[depth] = H::default_leaf();
73        db.put(Key(depth, 0).into(), H::serialize(cache[depth]))?;
74        for i in (0..depth).rev() {
75            cache[i] = H::hash(&[cache[i + 1], cache[i + 1]]);
76            db.put(Key(i, 0).into(), H::serialize(cache[i]))?;
77        }
78
79        let root = cache[0];
80
81        Ok(Self {
82            db,
83            depth,
84            next_index,
85            cache,
86            root,
87        })
88    }
89
90    /// Loads existing Merkle Tree from the specified path/db
91    pub fn load(db_config: D::Config) -> PmtreeResult<Self> {
92        // Load existing db instance
93        let db = D::load(db_config)?;
94
95        // Load root
96        let root = match db.get(Key(0, 0).into())? {
97            Some(root) => H::deserialize(root),
98            None => H::default_leaf(),
99        };
100
101        // Load depth & next_index values from db
102        let depth = match db.get(DEPTH_KEY)? {
103            Some(depth) => usize::from_be_bytes(depth.try_into().unwrap()),
104            None => DEFAULT_TREE_DEPTH,
105        };
106
107        let next_index = match db.get(NEXT_INDEX_KEY)? {
108            Some(next_index) => usize::from_be_bytes(next_index.try_into().unwrap()),
109            None => 0,
110        };
111
112        // Load cache vec
113        let mut cache = vec![H::default_leaf(); depth + 1];
114        cache[depth] = H::default_leaf();
115        for i in (0..depth).rev() {
116            cache[i] = H::hash(&[cache[i + 1], cache[i + 1]]);
117        }
118
119        Ok(Self {
120            db,
121            depth,
122            next_index,
123            cache,
124            root,
125        })
126    }
127
128    /// Closes the db connection
129    pub fn close(&mut self) -> PmtreeResult<()> {
130        self.db.close()
131    }
132
133    /// Sets a leaf at the specified tree index
134    pub fn set(&mut self, key: usize, leaf: H::Fr) -> PmtreeResult<()> {
135        if key >= self.capacity() {
136            return Err(PmtreeErrorKind::TreeError(TreeErrorKind::IndexOutOfBounds));
137        }
138
139        self.db
140            .put(Key(self.depth, key).into(), H::serialize(leaf))?;
141        self.recalculate_from(key)?;
142
143        // Update next_index in memory
144        self.next_index = max(self.next_index, key + 1);
145
146        // Update next_index in db
147        let next_index_val = self.next_index.to_be_bytes().to_vec();
148        self.db.put(NEXT_INDEX_KEY, next_index_val)?;
149
150        Ok(())
151    }
152
153    // Recalculates `Merkle Tree` from the specified key
154    fn recalculate_from(&mut self, key: usize) -> PmtreeResult<()> {
155        let mut depth = self.depth;
156        let mut i = key;
157
158        loop {
159            let value = self.hash_couple(depth, i)?;
160            i >>= 1;
161            depth -= 1;
162            self.db.put(Key(depth, i).into(), H::serialize(value))?;
163
164            if depth == 0 {
165                self.root = value;
166                break;
167            }
168        }
169
170        Ok(())
171    }
172
173    // Hashes the correct couple for the key
174    fn hash_couple(&self, depth: usize, key: usize) -> PmtreeResult<H::Fr> {
175        let b = key & !1;
176        Ok(H::hash(&[
177            self.get_elem(Key(depth, b))?,
178            self.get_elem(Key(depth, b + 1))?,
179        ]))
180    }
181
182    // Returns elem by the key
183    pub fn get_elem(&self, key: Key) -> PmtreeResult<H::Fr> {
184        let res = self
185            .db
186            .get(key.into())?
187            .map_or(self.cache[key.0], |value| H::deserialize(value));
188
189        Ok(res)
190    }
191
192    /// Deletes a leaf at the `key` by setting it to its default value
193    pub fn delete(&mut self, key: usize) -> PmtreeResult<()> {
194        if key >= self.next_index {
195            return Err(PmtreeErrorKind::TreeError(TreeErrorKind::InvalidKey));
196        }
197
198        self.set(key, H::default_leaf())?;
199
200        Ok(())
201    }
202
203    /// Inserts a leaf to the next available index
204    pub fn update_next(&mut self, leaf: H::Fr) -> PmtreeResult<()> {
205        self.set(self.next_index, leaf)?;
206
207        Ok(())
208    }
209
210    /// Batch insertion from starting index
211    pub fn set_range<I: IntoIterator<Item = H::Fr>>(
212        &mut self,
213        start: usize,
214        leaves: I,
215    ) -> PmtreeResult<()> {
216        self.batch_insert(
217            Some(start),
218            leaves.into_iter().collect::<Vec<_>>().as_slice(),
219        )
220    }
221
222    /// Batch insertion, updates the tree in parallel.
223    pub fn batch_insert(&mut self, start: Option<usize>, leaves: &[H::Fr]) -> PmtreeResult<()> {
224        let start = start.unwrap_or(self.next_index);
225        let end = start + leaves.len();
226
227        if end > self.capacity() {
228            return Err(PmtreeErrorKind::TreeError(TreeErrorKind::MerkleTreeIsFull));
229        }
230
231        let mut subtree = HashMap::<Key, H::Fr>::new();
232
233        let root_key = Key(0, 0);
234
235        subtree.insert(root_key, self.root);
236        self.fill_nodes(root_key, start, end, &mut subtree, leaves, start)?;
237
238        let subtree = Arc::new(RwLock::new(subtree));
239
240        let root_val = rayon::ThreadPoolBuilder::new()
241            .num_threads(rayon::current_num_threads())
242            .build()
243            .unwrap()
244            .install(|| Self::batch_recalculate(root_key, Arc::clone(&subtree), self.depth));
245
246        let subtree = RwLock::into_inner(Arc::try_unwrap(subtree).unwrap()).unwrap();
247
248        self.db.put_batch(
249            subtree
250                .into_iter()
251                .map(|(key, value)| (key.into(), H::serialize(value)))
252                .collect(),
253        )?;
254
255        // Update next_index value in db
256        if end > self.next_index {
257            self.next_index = end;
258            self.db
259                .put(NEXT_INDEX_KEY, self.next_index.to_be_bytes().to_vec())?;
260        }
261
262        // Update root value in memory
263        self.root = root_val;
264
265        Ok(())
266    }
267
268    // Fills hashmap subtree
269    fn fill_nodes(
270        &self,
271        key: Key,
272        start: usize,
273        end: usize,
274        subtree: &mut HashMap<Key, H::Fr>,
275        leaves: &[H::Fr],
276        from: usize,
277    ) -> PmtreeResult<()> {
278        if key.0 == self.depth {
279            if key.1 >= from {
280                subtree.insert(key, leaves[key.1 - from]);
281            }
282            return Ok(());
283        }
284
285        let left = Key(key.0 + 1, key.1 * 2);
286        let right = Key(key.0 + 1, key.1 * 2 + 1);
287
288        let left_val = self.get_elem(left)?;
289        let right_val = self.get_elem(right)?;
290
291        subtree.insert(left, left_val);
292        subtree.insert(right, right_val);
293
294        let half = 1 << (self.depth - key.0 - 1);
295
296        if start < half {
297            self.fill_nodes(left, start, min(end, half), subtree, leaves, from)?;
298        }
299
300        if end > half {
301            self.fill_nodes(right, 0, end - half, subtree, leaves, from)?;
302        }
303
304        Ok(())
305    }
306
307    // Recalculates tree in parallel (in-memory)
308    fn batch_recalculate(
309        key: Key,
310        subtree: Arc<RwLock<HashMap<Key, H::Fr>>>,
311        depth: usize,
312    ) -> H::Fr {
313        let left_child = Key(key.0 + 1, key.1 * 2);
314        let right_child = Key(key.0 + 1, key.1 * 2 + 1);
315
316        if key.0 == depth || !subtree.read().unwrap().contains_key(&left_child) {
317            return *subtree.read().unwrap().get(&key).unwrap();
318        }
319
320        let (left, right) = rayon::join(
321            || Self::batch_recalculate(left_child, Arc::clone(&subtree), depth),
322            || Self::batch_recalculate(right_child, Arc::clone(&subtree), depth),
323        );
324
325        let result = H::hash(&[left, right]);
326
327        subtree.write().unwrap().insert(key, result);
328
329        result
330    }
331
332    /// Computes a Merkle proof for the leaf at the specified index
333    pub fn proof(&self, index: usize) -> PmtreeResult<MerkleProof<H>> {
334        if index >= self.capacity() {
335            return Err(PmtreeErrorKind::TreeError(TreeErrorKind::IndexOutOfBounds));
336        }
337
338        let mut witness = Vec::with_capacity(self.depth);
339
340        let mut i = index;
341        let mut depth = self.depth;
342        while depth != 0 {
343            i ^= 1;
344            witness.push((
345                self.get_elem(Key(depth, i))?,
346                (1 - (i & 1)).try_into().unwrap(),
347            ));
348            i >>= 1;
349            depth -= 1;
350        }
351
352        Ok(MerkleProof(witness))
353    }
354
355    /// Verifies a Merkle proof with respect to the input leaf and the tree root
356    pub fn verify(&self, leaf: &H::Fr, witness: &MerkleProof<H>) -> bool {
357        let expected_root = witness.compute_root_from(leaf);
358
359        self.root() == expected_root
360    }
361
362    /// Returns the leaf by the key
363    pub fn get(&self, key: usize) -> PmtreeResult<H::Fr> {
364        if key >= self.capacity() {
365            return Err(PmtreeErrorKind::TreeError(TreeErrorKind::IndexOutOfBounds));
366        }
367
368        self.get_elem(Key(self.depth, key))
369    }
370
371    /// Returns the root of the tree
372    pub fn root(&self) -> H::Fr {
373        self.root
374    }
375
376    /// Returns the total number of leaves set
377    pub fn leaves_set(&self) -> usize {
378        self.next_index
379    }
380
381    /// Returns the capacity of the tree, i.e. the maximum number of leaves
382    pub fn capacity(&self) -> usize {
383        1 << self.depth
384    }
385
386    /// Returns the depth of the tree
387    pub fn depth(&self) -> usize {
388        self.depth
389    }
390}
391
392impl<H: Hasher> MerkleProof<H> {
393    /// Computes the Merkle root by iteratively hashing specified Merkle proof with specified leaf
394    pub fn compute_root_from(&self, leaf: &H::Fr) -> H::Fr {
395        let mut acc = *leaf;
396        for w in self.0.iter() {
397            if w.1 == 0 {
398                acc = H::hash(&[acc, w.0]);
399            } else {
400                acc = H::hash(&[w.0, acc]);
401            }
402        }
403
404        acc
405    }
406
407    /// Computes the leaf index corresponding to a Merkle proof
408    pub fn leaf_index(&self) -> usize {
409        self.get_path_index()
410            .into_iter()
411            .rev()
412            .fold(0, |acc, digit| (acc << 1) + usize::from(digit))
413    }
414
415    /// Returns the path indexes forming a Merkle Proof
416    pub fn get_path_index(&self) -> Vec<u8> {
417        self.0.iter().map(|x| x.1).collect()
418    }
419
420    /// Returns the path elements forming a Merkle proof
421    pub fn get_path_elements(&self) -> Vec<H::Fr> {
422        self.0.iter().map(|x| x.0).collect()
423    }
424
425    /// Returns the length of a Merkle proof
426    pub fn length(&self) -> usize {
427        self.0.len()
428    }
429}