1use crate::*;
2
3use std::cmp::{max, min};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
// Reserved DB key under which the tree depth is persisted.
// NOTE(review): assumes node keys (Cantor pairings of depth/index) never
// reach u64::MAX - 1 — holds for realistic tree depths, verify for extremes.
const DEPTH_KEY: DBKey = (u64::MAX - 1).to_be_bytes();

// Reserved DB key under which the next free leaf index is persisted.
const NEXT_INDEX_KEY: DBKey = u64::MAX.to_be_bytes();

// Depth assumed by `load` when the database has no DEPTH_KEY entry.
const DEFAULT_TREE_DEPTH: usize = 20;
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
19pub struct Key(usize, usize);
20impl From<Key> for DBKey {
21 fn from(key: Key) -> Self {
22 let cantor_pairing = ((key.0 + key.1) * (key.0 + key.1 + 1) / 2 + key.1) as u64;
23 cantor_pairing.to_be_bytes()
24 }
25}
26
/// A persistent Merkle tree, generic over a storage backend `D` and a hash
/// function `H`.
pub struct MerkleTree<D, H>
where
    D: Database,
    H: Hasher,
{
    /// Underlying key-value store holding every written node of the tree.
    pub db: D,
    // Number of levels below the root; leaf capacity is 1 << depth.
    depth: usize,
    // Index of the next unused leaf slot (high-water mark).
    next_index: usize,
    // cache[i] is the hash of an all-default subtree rooted at level i;
    // used as the fallback for nodes never written to the database.
    cache: Vec<H::Fr>,
    // Current root hash, kept in memory for cheap reads.
    root: H::Fr,
}
39
/// A Merkle inclusion proof: for each level from leaf to root, the sibling
/// hash and the path node's branch bit (0 = path node is the left child,
/// 1 = right child).
#[derive(Clone, PartialEq, Eq)]
pub struct MerkleProof<H: Hasher>(pub Vec<(H::Fr, u8)>);
43
44impl<D, H> MerkleTree<D, H>
45where
46 D: Database,
47 H: Hasher,
48{
    /// Creates a tree of the given depth using the backend's default config.
    pub fn default(depth: usize) -> PmtreeResult<Self> {
        Self::new(depth, D::Config::default())
    }
53
    /// Creates a new tree of `depth` levels backed by a freshly created
    /// database.
    ///
    /// Persists the depth and the next-free-leaf index, precomputes the
    /// default (all-empty) subtree hash for every level, and stores those
    /// defaults in the database at index 0 of each level.
    pub fn new(depth: usize, db_config: D::Config) -> PmtreeResult<Self> {
        let mut db = D::new(db_config)?;

        // Persist the depth so `load` can restore the same geometry later.
        let depth_val = depth.to_be_bytes().to_vec();
        db.put(DEPTH_KEY, depth_val)?;

        // No leaves inserted yet.
        let next_index = 0usize;
        let next_index_val = next_index.to_be_bytes().to_vec();
        db.put(NEXT_INDEX_KEY, next_index_val)?;

        // cache[i] = hash of an empty subtree rooted at level i;
        // built bottom-up from the default leaf.
        let mut cache = vec![H::default_leaf(); depth + 1];

        cache[depth] = H::default_leaf();
        db.put(Key(depth, 0).into(), H::serialize(cache[depth]))?;
        for i in (0..depth).rev() {
            cache[i] = H::hash(&[cache[i + 1], cache[i + 1]]);
            db.put(Key(i, 0).into(), H::serialize(cache[i]))?;
        }

        // Empty tree: the root is the default hash of the whole tree.
        let root = cache[0];

        Ok(Self {
            db,
            depth,
            next_index,
            cache,
            root,
        })
    }
89
    /// Reopens a tree previously persisted by `new`/`set`/`batch_insert`.
    ///
    /// Falls back to `H::default_leaf()` for the root, `DEFAULT_TREE_DEPTH`
    /// for the depth, and 0 for the next index when the corresponding
    /// database entries are missing.
    pub fn load(db_config: D::Config) -> PmtreeResult<Self> {
        let db = D::load(db_config)?;

        // Key(0, 0) is where the root node is stored.
        let root = match db.get(Key(0, 0).into())? {
            Some(root) => H::deserialize(root),
            None => H::default_leaf(),
        };

        let depth = match db.get(DEPTH_KEY)? {
            // NOTE(review): panics if the stored value is not exactly
            // size_of::<usize>() bytes (corrupted DB) — consider mapping
            // this to a tree error instead of unwrapping.
            Some(depth) => usize::from_be_bytes(depth.try_into().unwrap()),
            None => DEFAULT_TREE_DEPTH,
        };

        let next_index = match db.get(NEXT_INDEX_KEY)? {
            // NOTE(review): same unwrap-on-corruption caveat as above.
            Some(next_index) => usize::from_be_bytes(next_index.try_into().unwrap()),
            None => 0,
        };

        // Rebuild the default-subtree hashes; they are deterministic for a
        // given depth, so they are recomputed rather than read from the DB.
        let mut cache = vec![H::default_leaf(); depth + 1];
        cache[depth] = H::default_leaf();
        for i in (0..depth).rev() {
            cache[i] = H::hash(&[cache[i + 1], cache[i + 1]]);
        }

        Ok(Self {
            db,
            depth,
            next_index,
            cache,
            root,
        })
    }
127
    /// Closes the underlying database, propagating any backend error.
    pub fn close(&mut self) -> PmtreeResult<()> {
        self.db.close()
    }
132
133 pub fn set(&mut self, key: usize, leaf: H::Fr) -> PmtreeResult<()> {
135 if key >= self.capacity() {
136 return Err(PmtreeErrorKind::TreeError(TreeErrorKind::IndexOutOfBounds));
137 }
138
139 self.db
140 .put(Key(self.depth, key).into(), H::serialize(leaf))?;
141 self.recalculate_from(key)?;
142
143 self.next_index = max(self.next_index, key + 1);
145
146 let next_index_val = self.next_index.to_be_bytes().to_vec();
148 self.db.put(NEXT_INDEX_KEY, next_index_val)?;
149
150 Ok(())
151 }
152
153 fn recalculate_from(&mut self, key: usize) -> PmtreeResult<()> {
155 let mut depth = self.depth;
156 let mut i = key;
157
158 loop {
159 let value = self.hash_couple(depth, i)?;
160 i >>= 1;
161 depth -= 1;
162 self.db.put(Key(depth, i).into(), H::serialize(value))?;
163
164 if depth == 0 {
165 self.root = value;
166 break;
167 }
168 }
169
170 Ok(())
171 }
172
173 fn hash_couple(&self, depth: usize, key: usize) -> PmtreeResult<H::Fr> {
175 let b = key & !1;
176 Ok(H::hash(&[
177 self.get_elem(Key(depth, b))?,
178 self.get_elem(Key(depth, b + 1))?,
179 ]))
180 }
181
182 pub fn get_elem(&self, key: Key) -> PmtreeResult<H::Fr> {
184 let res = self
185 .db
186 .get(key.into())?
187 .map_or(self.cache[key.0], |value| H::deserialize(value));
188
189 Ok(res)
190 }
191
192 pub fn delete(&mut self, key: usize) -> PmtreeResult<()> {
194 if key >= self.next_index {
195 return Err(PmtreeErrorKind::TreeError(TreeErrorKind::InvalidKey));
196 }
197
198 self.set(key, H::default_leaf())?;
199
200 Ok(())
201 }
202
    /// Appends `leaf` at the next free index.
    pub fn update_next(&mut self, leaf: H::Fr) -> PmtreeResult<()> {
        self.set(self.next_index, leaf)?;

        Ok(())
    }
209
210 pub fn set_range<I: IntoIterator<Item = H::Fr>>(
212 &mut self,
213 start: usize,
214 leaves: I,
215 ) -> PmtreeResult<()> {
216 self.batch_insert(
217 Some(start),
218 leaves.into_iter().collect::<Vec<_>>().as_slice(),
219 )
220 }
221
222 pub fn batch_insert(&mut self, start: Option<usize>, leaves: &[H::Fr]) -> PmtreeResult<()> {
224 let start = start.unwrap_or(self.next_index);
225 let end = start + leaves.len();
226
227 if end > self.capacity() {
228 return Err(PmtreeErrorKind::TreeError(TreeErrorKind::MerkleTreeIsFull));
229 }
230
231 let mut subtree = HashMap::<Key, H::Fr>::new();
232
233 let root_key = Key(0, 0);
234
235 subtree.insert(root_key, self.root);
236 self.fill_nodes(root_key, start, end, &mut subtree, leaves, start)?;
237
238 let subtree = Arc::new(RwLock::new(subtree));
239
240 let root_val = rayon::ThreadPoolBuilder::new()
241 .num_threads(rayon::current_num_threads())
242 .build()
243 .unwrap()
244 .install(|| Self::batch_recalculate(root_key, Arc::clone(&subtree), self.depth));
245
246 let subtree = RwLock::into_inner(Arc::try_unwrap(subtree).unwrap()).unwrap();
247
248 self.db.put_batch(
249 subtree
250 .into_iter()
251 .map(|(key, value)| (key.into(), H::serialize(value)))
252 .collect(),
253 )?;
254
255 if end > self.next_index {
257 self.next_index = end;
258 self.db
259 .put(NEXT_INDEX_KEY, self.next_index.to_be_bytes().to_vec())?;
260 }
261
262 self.root = root_val;
264
265 Ok(())
266 }
267
    /// Recursively populates `subtree` with every node affected by writing
    /// `leaves` into the leaf range [start, end) of the subtree rooted at
    /// `key`, together with the current sibling values needed to rehash the
    /// touched paths.
    ///
    /// `start`/`end` are leaf offsets relative to the current subtree, while
    /// node indices (`key.1`) and `from` (the absolute index of `leaves[0]`)
    /// are absolute within the whole tree.
    fn fill_nodes(
        &self,
        key: Key,
        start: usize,
        end: usize,
        subtree: &mut HashMap<Key, H::Fr>,
        leaves: &[H::Fr],
        from: usize,
    ) -> PmtreeResult<()> {
        // Leaf level: overwrite with the new batch value. The guard skips
        // leaves before the batch's first index — presumably reachable only
        // as the seeded left-boundary sibling; TODO confirm.
        if key.0 == self.depth {
            if key.1 >= from {
                subtree.insert(key, leaves[key.1 - from]);
            }
            return Ok(());
        }

        let left = Key(key.0 + 1, key.1 * 2);
        let right = Key(key.0 + 1, key.1 * 2 + 1);

        // Seed both children with their current values so untouched
        // siblings are available when rehashing upward.
        let left_val = self.get_elem(left)?;
        let right_val = self.get_elem(right)?;

        subtree.insert(left, left_val);
        subtree.insert(right, right_val);

        // Number of leaves under each child of this node.
        let half = 1 << (self.depth - key.0 - 1);

        // Recurse only into child subtrees the [start, end) range overlaps;
        // the right child's range is re-based to its own origin.
        if start < half {
            self.fill_nodes(left, start, min(end, half), subtree, leaves, from)?;
        }

        if end > half {
            self.fill_nodes(right, 0, end - half, subtree, leaves, from)?;
        }

        Ok(())
    }
306
    /// Recomputes the hash of `key` from its children in `subtree`,
    /// recursing into both children in parallel via `rayon::join`.
    ///
    /// A node whose left child is absent from `subtree` was not touched by
    /// the batch, so its previously seeded value is returned as-is.
    fn batch_recalculate(
        key: Key,
        subtree: Arc<RwLock<HashMap<Key, H::Fr>>>,
        depth: usize,
    ) -> H::Fr {
        let left_child = Key(key.0 + 1, key.1 * 2);
        let right_child = Key(key.0 + 1, key.1 * 2 + 1);

        // Base case: a leaf, or an unmodified subtree — return the stored
        // value. (Two short read-lock acquisitions rather than one held
        // across both checks.)
        if key.0 == depth || !subtree.read().unwrap().contains_key(&left_child) {
            return *subtree.read().unwrap().get(&key).unwrap();
        }

        // Recompute both children concurrently; no locks are held across
        // the join, only inside the recursive calls.
        let (left, right) = rayon::join(
            || Self::batch_recalculate(left_child, Arc::clone(&subtree), depth),
            || Self::batch_recalculate(right_child, Arc::clone(&subtree), depth),
        );

        let result = H::hash(&[left, right]);

        // Record the recomputed node so the parent can read it.
        subtree.write().unwrap().insert(key, result);

        result
    }
331
    /// Builds a Merkle inclusion proof for the leaf at `index`: one
    /// (sibling hash, branch bit) pair per level, ordered leaf → root.
    /// The branch bit is 0 when the path node is a left child.
    pub fn proof(&self, index: usize) -> PmtreeResult<MerkleProof<H>> {
        if index >= self.capacity() {
            return Err(PmtreeErrorKind::TreeError(TreeErrorKind::IndexOutOfBounds));
        }

        let mut witness = Vec::with_capacity(self.depth);

        let mut i = index;
        let mut depth = self.depth;
        while depth != 0 {
            // Flip the lowest bit: `i` now names the sibling at this level.
            i ^= 1;
            witness.push((
                self.get_elem(Key(depth, i))?,
                // `i & 1` is the sibling's side, so `1 - (i & 1)` is the
                // path node's side (0 = left child).
                (1 - (i & 1)).try_into().unwrap(),
            ));
            // Move up to the parent.
            i >>= 1;
            depth -= 1;
        }

        Ok(MerkleProof(witness))
    }
354
355 pub fn verify(&self, leaf: &H::Fr, witness: &MerkleProof<H>) -> bool {
357 let expected_root = witness.compute_root_from(leaf);
358
359 self.root() == expected_root
360 }
361
362 pub fn get(&self, key: usize) -> PmtreeResult<H::Fr> {
364 if key >= self.capacity() {
365 return Err(PmtreeErrorKind::TreeError(TreeErrorKind::IndexOutOfBounds));
366 }
367
368 self.get_elem(Key(self.depth, key))
369 }
370
    /// Returns the current root hash.
    pub fn root(&self) -> H::Fr {
        self.root
    }
375
    /// Returns the number of leaf slots used so far (highest set index + 1).
    pub fn leaves_set(&self) -> usize {
        self.next_index
    }
380
    /// Returns the total number of leaf slots: 2^depth.
    pub fn capacity(&self) -> usize {
        1 << self.depth
    }
385
    /// Returns the number of levels below the root.
    pub fn depth(&self) -> usize {
        self.depth
    }
390}
391
392impl<H: Hasher> MerkleProof<H> {
393 pub fn compute_root_from(&self, leaf: &H::Fr) -> H::Fr {
395 let mut acc = *leaf;
396 for w in self.0.iter() {
397 if w.1 == 0 {
398 acc = H::hash(&[acc, w.0]);
399 } else {
400 acc = H::hash(&[w.0, acc]);
401 }
402 }
403
404 acc
405 }
406
407 pub fn leaf_index(&self) -> usize {
409 self.get_path_index()
410 .into_iter()
411 .rev()
412 .fold(0, |acc, digit| (acc << 1) + usize::from(digit))
413 }
414
415 pub fn get_path_index(&self) -> Vec<u8> {
417 self.0.iter().map(|x| x.1).collect()
418 }
419
420 pub fn get_path_elements(&self) -> Vec<H::Fr> {
422 self.0.iter().map(|x| x.0).collect()
423 }
424
425 pub fn length(&self) -> usize {
427 self.0.len()
428 }
429}