1use crate::hash_pair;
2use crate::unbalanced::UnbalancedMerkleTree;
3use ab_blake3::OUT_LEN;
4#[cfg(feature = "alloc")]
5use alloc::boxed::Box;
6#[cfg(feature = "alloc")]
7use alloc::vec::Vec;
8use core::mem;
9use core::mem::MaybeUninit;
10use core::ops::{Deref, DerefMut};
11
/// Compact set of peaks of a Merkle Mountain Range with at most `MAX_N` leaves.
///
/// Together with [`MerkleMountainRange::from_peaks()`] this is sufficient to reconstruct the
/// full MMR state.
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
pub struct MmrPeaks<const MAX_N: u64>
where
    [(); MAX_N.ilog2() as usize + 1]:,
{
    /// Number of leaves that have been added to the MMR these peaks belong to
    pub num_leaves: u64,
    /// Peak hashes, densely packed starting from the lowest stack level (smallest subtree).
    ///
    /// Only the first `num_leaves.count_ones()` entries are meaningful (see
    /// [`MmrPeaks::num_peaks()`]); the remaining entries are zero-filled padding, since the
    /// array is sized for the worst case of `ilog2(MAX_N) + 1` peaks.
    pub peaks: [[u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
}
27
28impl<const MAX_N: u64> MmrPeaks<MAX_N>
29where
30 [(); MAX_N.ilog2() as usize + 1]:,
31{
32 #[inline(always)]
34 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
35 pub fn num_peaks(&self) -> u8 {
36 self.num_leaves.count_ones() as u8
37 }
38}
39
/// Byte representation of [`MerkleMountainRange`].
///
/// `#[repr(C, align(8))]` gives it the same size and alignment as [`MerkleMountainRange`]
/// (checked by compile-time assertions in this file), which allows zero-copy
/// reinterpretation via [`MerkleMountainRange::as_bytes()`] and
/// [`MerkleMountainRange::from_bytes()`].
#[derive(Debug, Copy, Clone)]
#[repr(C, align(8))]
pub struct MerkleMountainRangeBytes<const MAX_N: u64>(
    [u8; merkle_mountain_range_bytes_size(MAX_N)],
)
where
    [(); merkle_mountain_range_bytes_size(MAX_N)]:;
50
51impl<const MAX_N: u64> Default for MerkleMountainRangeBytes<MAX_N>
52where
53 [(); merkle_mountain_range_bytes_size(MAX_N)]:,
54{
55 #[inline(always)]
56 fn default() -> Self {
57 Self([0; _])
58 }
59}
60
61impl<const MAX_N: u64> From<[u8; merkle_mountain_range_bytes_size(MAX_N)]>
62 for MerkleMountainRangeBytes<MAX_N>
63where
64 [(); merkle_mountain_range_bytes_size(MAX_N)]:,
65{
66 fn from(value: [u8; merkle_mountain_range_bytes_size(MAX_N)]) -> Self {
67 Self(value)
68 }
69}
70
71impl<const MAX_N: u64> From<MerkleMountainRangeBytes<MAX_N>>
72 for [u8; merkle_mountain_range_bytes_size(MAX_N)]
73where
74 [(); merkle_mountain_range_bytes_size(MAX_N)]:,
75{
76 fn from(value: MerkleMountainRangeBytes<MAX_N>) -> Self {
77 value.0
78 }
79}
80
// Smart-pointer-style access to the underlying byte array
impl<const MAX_N: u64> Deref for MerkleMountainRangeBytes<MAX_N>
where
    [(); merkle_mountain_range_bytes_size(MAX_N)]:,
{
    type Target = [u8; merkle_mountain_range_bytes_size(MAX_N)];

    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
92
// Mutable counterpart to the `Deref` impl above; no invariants are attached to the raw
// bytes themselves, so handing out `&mut` is fine
impl<const MAX_N: u64> DerefMut for MerkleMountainRangeBytes<MAX_N>
where
    [(); merkle_mountain_range_bytes_size(MAX_N)]:,
{
    #[inline(always)]
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
102
103pub const fn merkle_mountain_range_bytes_size(max_n: u64) -> usize {
105 size_of::<u64>() + OUT_LEN * (max_n.ilog2() as usize + 1)
106}
107
// Compile-time guarantee that `MerkleMountainRange` and `MerkleMountainRangeBytes` have
// identical size and alignment; the `transmute`s in `as_bytes()`/`from_bytes()` rely on
// this (checked for one representative `MAX_N`)
const _: () = {
    assert!(size_of::<MerkleMountainRangeBytes<2>>() == merkle_mountain_range_bytes_size(2));
    assert!(size_of::<MerkleMountainRange<2>>() == merkle_mountain_range_bytes_size(2));
    assert!(align_of::<MerkleMountainRangeBytes<2>>() == align_of::<MerkleMountainRange<2>>());
};
113
/// Merkle Mountain Range supporting at most `MAX_N` leaves.
///
/// Leaves are accumulated into a stack of perfect-subtree roots: bit `i` of `num_leaves`
/// being set means `stack[i]` currently holds the root ("peak") of a perfect subtree
/// covering `2^i` leaves. Adding a leaf behaves like a binary increment with
/// carry-propagation through the stack (see `add_leaf()`).
#[derive(Debug, Copy, Clone)]
#[repr(C)]
pub struct MerkleMountainRange<const MAX_N: u64>
where
    [(); MAX_N.ilog2() as usize + 1]:,
{
    // Number of leaves added so far; every set bit indexes within `stack` (the unsafe
    // accesses in the methods below rely on this invariant)
    num_leaves: u64,
    // `stack[i]` is only meaningful while bit `i` of `num_leaves` is set; other entries
    // hold stale or zero values and are ignored
    stack: [[u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
}
133
impl<const MAX_N: u64> Default for MerkleMountainRange<MAX_N>
where
    [(); MAX_N.ilog2() as usize + 1]:,
{
    /// Same as [`MerkleMountainRange::new()`]: an empty MMR
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    fn default() -> Self {
        Self::new()
    }
}
144
impl<const MAX_N: u64> MerkleMountainRange<MAX_N>
where
    [(); MAX_N.ilog2() as usize + 1]:,
{
    /// Create an empty Merkle Mountain Range
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn new() -> Self {
        Self {
            num_leaves: 0,
            stack: [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
        }
    }

    /// Reconstruct an MMR from previously collected [`MmrPeaks`].
    ///
    /// Returns `None` if the peaks are inconsistent with this MMR's capacity, i.e. if
    /// `peaks.num_leaves` implies a peak at a level that does not fit into the stack.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn from_peaks(peaks: &MmrPeaks<MAX_N>) -> Option<Self> {
        let mut result = Self {
            num_leaves: peaks.num_leaves,
            stack: [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
        };

        // One peak per set bit of `num_leaves`: scatter the densely packed `peaks` array
        // back into the sparse stack, lowest level first
        let mut stack_bits = peaks.num_leaves;
        let mut peaks_offset = 0;

        while stack_bits != 0 {
            let stack_offset = stack_bits.trailing_zeros();

            // Checked indexing doubles as validation: a set bit beyond the stack length
            // (or more peaks than the array holds) bails out with `None`
            *result.stack.get_mut(stack_offset as usize)? = *peaks.peaks.get(peaks_offset)?;

            peaks_offset += 1;
            stack_bits &= !(1 << stack_offset);
        }

        Some(result)
    }

    /// Reinterpret the MMR as its byte representation (zero-copy)
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn as_bytes(&self) -> &MerkleMountainRangeBytes<MAX_N>
    where
        [(); merkle_mountain_range_bytes_size(MAX_N)]:,
    {
        // SAFETY: both types have matching size and alignment (checked by the compile-time
        // assertions in this file), and any bit pattern of `self` is valid as plain bytes
        unsafe { mem::transmute(self) }
    }

    /// Reinterpret a byte representation as an MMR (zero-copy).
    ///
    /// # Safety
    /// `bytes` must contain a valid [`MerkleMountainRange`] (e.g. previously produced by
    /// [`Self::as_bytes()`]); arbitrary bytes may violate internal invariants, such as every
    /// set bit of `num_leaves` indexing within the stack, which the unsafe accesses below
    /// rely on.
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub unsafe fn from_bytes(bytes: &MerkleMountainRangeBytes<MAX_N>) -> &Self
    where
        [(); merkle_mountain_range_bytes_size(MAX_N)]:,
    {
        // SAFETY: size/alignment match per the compile-time assertions; validity of the
        // contents is the caller's contract (see `# Safety` above)
        unsafe { mem::transmute(bytes) }
    }

    /// Number of leaves added so far
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn num_leaves(&self) -> u64 {
        self.num_leaves
    }

    /// Compute the root hash by "bagging" all peaks from the lowest level to the highest,
    /// hashing each higher peak on the left with the accumulated root on the right.
    ///
    /// Returns `None` for an empty MMR.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn root(&self) -> Option<[u8; OUT_LEN]> {
        if self.num_leaves == 0 {
            return None;
        }

        let mut root;
        let mut stack_bits = self.num_leaves;
        {
            // Start from the lowest (smallest-subtree) peak
            let lowest_active_level = stack_bits.trailing_zeros() as usize;
            // SAFETY: every set bit of `num_leaves` indexes within the stack (type
            // invariant, upheld by `from_peaks()` validation and the `MAX_N` bound check in
            // the `add_leaf*()` methods)
            root = *unsafe { self.stack.get_unchecked(lowest_active_level) };
            stack_bits &= !(1 << lowest_active_level);
        }

        loop {
            let lowest_active_level = stack_bits.trailing_zeros() as usize;

            // `trailing_zeros() == 64` means no set bits remain
            if lowest_active_level == u64::BITS as usize {
                break;
            }

            stack_bits &= !(1 << lowest_active_level);

            // SAFETY: same type invariant as above, set bits always index within the stack
            let lowest_active_level_item = unsafe { self.stack.get_unchecked(lowest_active_level) };

            root = hash_pair(lowest_active_level_item, &root);
        }

        Some(root)
    }

    /// Collect the current peaks into a compact [`MmrPeaks`], densely packed from the
    /// lowest level upward; unused trailing entries are zero-filled
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn peaks(&self) -> MmrPeaks<MAX_N> {
        let mut result = MmrPeaks {
            num_leaves: self.num_leaves,
            peaks: [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
        };

        // One peak per set bit of `num_leaves`
        let mut stack_bits = self.num_leaves;
        let mut peaks_offset = 0;
        while stack_bits != 0 {
            let stack_offset = stack_bits.trailing_zeros();

            // SAFETY: `peaks_offset` is bounded by the number of set bits, which never
            // exceeds the array length, and every set bit of `num_leaves` indexes within
            // the stack (type invariant)
            *unsafe { result.peaks.get_unchecked_mut(peaks_offset) } =
                *unsafe { self.stack.get_unchecked(stack_offset as usize) };

            peaks_offset += 1;
            stack_bits &= !(1 << stack_offset);
        }

        result
    }

    /// Add a leaf hash to the MMR.
    ///
    /// Returns `false` (leaving the MMR unmodified) if it is already full.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn add_leaf(&mut self, leaf: &[u8; OUT_LEN]) -> bool {
        if self.num_leaves >= MAX_N {
            return false;
        }

        let mut current = *leaf;

        // Like a binary increment: each occupied low level is merged into the new node and
        // carried one level up until a free level is found
        let lowest_active_levels = self.num_leaves.trailing_ones() as usize;
        for item in self.stack.iter().take(lowest_active_levels) {
            current = hash_pair(item, &current);
        }

        // SAFETY: `num_leaves < MAX_N` implies `trailing_ones(num_leaves) <= ilog2(MAX_N)`
        // (a value with `t` trailing ones is at least `2^t - 1`), so the index is within
        // the `ilog2(MAX_N) + 1` stack length
        *unsafe { self.stack.get_unchecked_mut(lowest_active_levels) } = current;
        self.num_leaves += 1;

        true
    }

    /// Add multiple leaves; stops and returns `false` as soon as the MMR is full (leaves
    /// added before that point are retained)
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn add_leaves<'a, Item, Iter>(&mut self, leaves: Iter) -> bool
    where
        Item: Into<[u8; OUT_LEN]>,
        Iter: IntoIterator<Item = Item> + 'a,
    {
        for leaf in leaves {
            if self.num_leaves >= MAX_N {
                return false;
            }

            let mut current = leaf.into();

            // Same carry-propagation as in `add_leaf()`
            let lowest_active_levels = self.num_leaves.trailing_ones() as usize;
            for item in self.stack.iter().take(lowest_active_levels) {
                current = hash_pair(item, &current);
            }

            self.stack[lowest_active_levels] = current;
            self.num_leaves += 1;
        }

        true
    }

    /// Add a leaf and also compute the new root together with a Merkle proof for that leaf
    /// (heap-allocating variant of [`Self::add_leaf_and_compute_proof_in()`]).
    ///
    /// Returns `None` (leaving the MMR unmodified) if it is already full.
    #[inline]
    #[cfg(feature = "alloc")]
    pub fn add_leaf_and_compute_proof(
        &mut self,
        leaf: &[u8; OUT_LEN],
    ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)> {
        // SAFETY: an array of `MaybeUninit` does not require initialization
        let mut proof = unsafe {
            Box::<[MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1]>::new_uninit()
                .assume_init()
        };

        let (root, proof_length) = self.add_leaf_and_compute_proof_inner(leaf, &mut proof)?;

        let proof_capacity = proof.len();
        let proof = Box::into_raw(proof);
        // SAFETY: exactly the first `proof_length` entries were initialized by the call
        // above; `MaybeUninit<T>` has the same layout as `T`, and the boxed array's
        // allocation matches a `Vec` with capacity equal to the full array length
        let proof = unsafe {
            Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
        };

        Some((root, proof))
    }

    /// Like [`Self::add_leaf_and_compute_proof()`], but writes the proof into a
    /// caller-provided buffer and returns the initialized prefix of that buffer.
    ///
    /// Returns `None` (leaving the MMR unmodified) if it is already full.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn add_leaf_and_compute_proof_in<'proof>(
        &mut self,
        leaf: &[u8; OUT_LEN],
        proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
    ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])> {
        let (root, proof_length) = self.add_leaf_and_compute_proof_inner(leaf, proof)?;

        // SAFETY: `proof_length` never exceeds the buffer length (at most one entry is
        // written per stack level), and exactly the first `proof_length` entries were
        // initialized by the call above
        let proof = unsafe {
            proof
                .split_at_mut_unchecked(proof_length)
                .0
                .assume_init_mut()
        };

        Some((root, proof))
    }

    /// Shared implementation of the `add_leaf_and_compute_proof*` methods: adds `leaf`,
    /// writes the proof hashes for it into `proof` and returns the new root together with
    /// the number of proof entries written.
    ///
    /// Returns `None` (leaving the MMR unmodified) if it is already full.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn add_leaf_and_compute_proof_inner(
        &mut self,
        leaf: &[u8; OUT_LEN],
        proof: &mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
    ) -> Option<([u8; OUT_LEN], usize)> {
        let mut proof_length = 0;

        // Stack level at which the new leaf (after carry-propagation) ends up
        let current_target_level;
        // Position of the new leaf's node within its level, halved as we walk up
        let mut position = self.num_leaves;

        // Phase 1: insert the leaf exactly like `add_leaf()`, recording every merged stack
        // item -- each one is the sibling of the path on the way up
        {
            if self.num_leaves >= MAX_N {
                return None;
            }

            let mut current = *leaf;

            let lowest_active_levels = self.num_leaves.trailing_ones() as usize;

            for item in self.stack.iter().take(lowest_active_levels) {
                // SAFETY: at most one proof entry is produced per stack level across both
                // phases, so `proof_length` stays below the `ilog2(MAX_N) + 1` buffer length
                unsafe { proof.get_unchecked_mut(proof_length) }.write(*item);
                proof_length += 1;

                current = hash_pair(item, &current);

                position /= 2;
            }

            current_target_level = lowest_active_levels;

            self.stack[lowest_active_levels] = current;
            self.num_leaves += 1;
        }

        // Phase 2: bag the peaks exactly as `root()` does, extending the proof with
        // whichever hash acts as the sibling of the new leaf's path at each level
        let mut root;
        let mut stack_bits = self.num_leaves;

        {
            let lowest_active_level = stack_bits.trailing_zeros() as usize;
            // SAFETY: every set bit of `num_leaves` indexes within the stack (type invariant)
            root = *unsafe { self.stack.get_unchecked(lowest_active_level) };
            stack_bits &= !(1 << lowest_active_level);
        }

        // Tracks whether the previous iteration merged two peaks below the target level
        // without contributing a proof entry.
        // NOTE(review): the merged-peaks bookkeeping encodes the proof shape expected by
        // `verify()` -- confirm against `UnbalancedMerkleTree`'s proof layout
        let mut merged_peaks = false;
        loop {
            let lowest_active_level = stack_bits.trailing_zeros() as usize;

            // `trailing_zeros() == 64` means no set bits remain
            if lowest_active_level == u64::BITS as usize {
                break;
            }

            stack_bits &= !(1 << lowest_active_level);

            // SAFETY: same type invariant as above, set bits always index within the stack
            let lowest_active_level_item = unsafe { self.stack.get_unchecked(lowest_active_level) };

            if lowest_active_level > current_target_level
                || (lowest_active_level == current_target_level
                    && !position.is_multiple_of(2)
                    && !merged_peaks)
            {
                // The peak at this level is the sibling of the accumulated path
                // SAFETY: at most one proof entry per stack level keeps `proof_length`
                // within the buffer length
                unsafe { proof.get_unchecked_mut(proof_length) }.write(*lowest_active_level_item);
                proof_length += 1;
                merged_peaks = false;
            } else if lowest_active_level == current_target_level {
                // The accumulated root is the sibling of the new leaf's subtree
                // SAFETY: at most one proof entry per stack level keeps `proof_length`
                // within the buffer length
                unsafe { proof.get_unchecked_mut(proof_length) }.write(root);
                proof_length += 1;
                merged_peaks = false;
            } else {
                // Peaks below the target level merge without producing a proof entry
                merged_peaks = true;
            }

            root = hash_pair(lowest_active_level_item, &root);

            position /= 2;
        }

        Some((root, proof_length))
    }

    /// Verify a Merkle proof for `leaf` at `leaf_index` against `root`.
    ///
    /// Thin wrapper delegating to [`UnbalancedMerkleTree::verify()`] with identical
    /// arguments and semantics.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn verify(
        root: &[u8; OUT_LEN],
        proof: &[[u8; OUT_LEN]],
        leaf_index: u64,
        leaf: [u8; OUT_LEN],
        num_leaves: u64,
    ) -> bool {
        UnbalancedMerkleTree::verify(root, proof, leaf_index, leaf, num_leaves)
    }
}