use crate::hash_pair;
use crate::unbalanced::UnbalancedMerkleTree;
use ab_blake3::OUT_LEN;
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use core::mem;
use core::mem::MaybeUninit;

/// Peaks of a Merkle Mountain Range with at most `MAX_N` leaves.
///
/// Together with the number of leaves, the peaks are a compact commitment to the whole data
/// structure, from which a [`MerkleMountainRange`] can be reconstructed with
/// [`MerkleMountainRange::from_peaks()`].
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
pub struct MmrPeaks<const MAX_N: u64>
where
    [(); MAX_N.ilog2() as usize + 1]:,
{
    /// Number of leaves added so far
    pub num_leaves: u64,
    /// Peak hashes, one for each set bit of `num_leaves`, starting with the lowest set bit (the
    /// smallest perfect subtree); only the first [`Self::num_peaks()`] entries are meaningful,
    /// the rest are zero-filled
    pub peaks: [[u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
}

impl<const MAX_N: u64> MmrPeaks<MAX_N>
where
    [(); MAX_N.ilog2() as usize + 1]:,
{
    /// Number of peaks that are actually in use.
    ///
    /// There is one peak per set bit of `num_leaves`, e.g. 11 leaves (`0b1011`) have 3 peaks.
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn num_peaks(&self) -> u8 {
        self.num_leaves.count_ones() as u8
    }
}

/// Byte representation of [`MerkleMountainRange`] with the same size and alignment
#[derive(Debug, Copy, Clone)]
#[repr(C, align(8))]
pub struct MerkleMountainRangeBytes<const MAX_N: u64>(
    pub [u8; merkle_mountain_range_bytes_size(MAX_N)],
)
where
    [(); merkle_mountain_range_bytes_size(MAX_N)]:;

/// Size in bytes of [`MerkleMountainRangeBytes`]: the leaf counter plus one hash per stack slot
pub const fn merkle_mountain_range_bytes_size(max_n: u64) -> usize {
    // `num_leaves` counter followed by `ilog2(max_n) + 1` hashes
    size_of::<u64>() + OUT_LEN * (max_n.ilog2() as usize + 1)
}
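
// For example, with `MAX_N = 1024` there are `1024_u64.ilog2() + 1 == 11` slots, so with 32-byte
// BLAKE3 hashes the total size is `8 + 32 * 11 == 360` bytes.
const _: () = assert!(merkle_mountain_range_bytes_size(1024) == size_of::<u64>() + OUT_LEN * 11);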

const _: () = {
    assert!(size_of::<MerkleMountainRangeBytes<2>>() == merkle_mountain_range_bytes_size(2));
    assert!(size_of::<MerkleMountainRange<2>>() == merkle_mountain_range_bytes_size(2));
    assert!(align_of::<MerkleMountainRangeBytes<2>>() == align_of::<MerkleMountainRange<2>>());
};

/// Merkle Mountain Range for up to `MAX_N` leaves.
///
/// Leaves are appended one at a time and aggregated into a stack of perfect-subtree peaks, one
/// slot per level, so the memory footprint is `O(log2(MAX_N))` regardless of how many leaves have
/// been added.
#[derive(Debug, Copy, Clone)]
#[repr(C)]
pub struct MerkleMountainRange<const MAX_N: u64>
where
    [(); MAX_N.ilog2() as usize + 1]:,
{
    /// Number of leaves added so far
    num_leaves: u64,
    /// Stack of peak hashes indexed by level; slot `i` is live iff bit `i` of `num_leaves` is set
    stack: [[u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
}

impl<const MAX_N: u64> Default for MerkleMountainRange<MAX_N>
where
    [(); MAX_N.ilog2() as usize + 1]:,
{
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    fn default() -> Self {
        Self::new()
    }
}

impl<const MAX_N: u64> MerkleMountainRange<MAX_N>
where
    [(); MAX_N.ilog2() as usize + 1]:,
{
    /// Create an empty Merkle Mountain Range
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn new() -> Self {
        Self {
            num_leaves: 0,
            stack: [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
        }
    }

    /// Reconstruct a Merkle Mountain Range from previously collected [`MmrPeaks`].
    ///
    /// Returns `None` if the number of peaks exceeds the number of stack slots supported by
    /// `MAX_N`.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn from_peaks(peaks: &MmrPeaks<MAX_N>) -> Option<Self> {
        let mut result = Self {
            num_leaves: peaks.num_leaves,
            stack: [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
        };

        // Every set bit of `num_leaves` corresponds to one peak, placed into the stack slot of
        // the matching level
        let mut stack_bits = peaks.num_leaves;
        let mut peaks_offset = 0;

        while stack_bits != 0 {
            let stack_offset = stack_bits.trailing_zeros();

            *result.stack.get_mut(stack_offset as usize)? = *peaks.peaks.get(peaks_offset)?;

            peaks_offset += 1;
            // Clear the bit that was just processed
            stack_bits &= !(1 << stack_offset);
        }

        Some(result)
    }

    /// Get byte representation of the Merkle Mountain Range
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn as_bytes(&self) -> &MerkleMountainRangeBytes<MAX_N>
    where
        [(); merkle_mountain_range_bytes_size(MAX_N)]:,
    {
        // SAFETY: `MerkleMountainRangeBytes` has the same size and alignment as `Self` (checked
        // statically above), and any bit pattern of `Self` is a valid byte array
        unsafe { mem::transmute(self) }
    }

    /// Create a reference to the Merkle Mountain Range from its byte representation.
    ///
    /// # Safety
    /// Bytes must be a valid byte representation previously produced by [`Self::as_bytes()`].
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub unsafe fn from_bytes(bytes: &MerkleMountainRangeBytes<MAX_N>) -> &Self
    where
        [(); merkle_mountain_range_bytes_size(MAX_N)]:,
    {
        // SAFETY: same size and alignment as `Self` (checked statically above), validity of the
        // contents is guaranteed by the caller
        unsafe { mem::transmute(bytes) }
    }

    /// Number of leaves added so far
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn num_leaves(&self) -> u64 {
        self.num_leaves
    }

    /// Calculate the root of the Merkle Mountain Range by "bagging" the peaks together.
    ///
    /// Returns `None` if no leaves were added yet.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn root(&self) -> Option<[u8; OUT_LEN]> {
        if self.num_leaves == 0 {
            return None;
        }

        let mut root;
        let mut stack_bits = self.num_leaves;
        {
            // Start from the smallest peak (lowest set bit of `num_leaves`)
            let lowest_active_level = stack_bits.trailing_zeros() as usize;
            // SAFETY: the level is within the stack bounds for any valid `num_leaves`
            root = *unsafe { self.stack.get_unchecked(lowest_active_level) };
            stack_bits &= !(1 << lowest_active_level);
        }

        loop {
            let lowest_active_level = stack_bits.trailing_zeros() as usize;

            // No more active levels left
            if lowest_active_level == u64::BITS as usize {
                break;
            }

            stack_bits &= !(1 << lowest_active_level);

            // SAFETY: the level is within the stack bounds for any valid `num_leaves`
            let lowest_active_level_item = unsafe { self.stack.get_unchecked(lowest_active_level) };

            // Fold the next (larger) peak into the accumulated root
            root = hash_pair(lowest_active_level_item, &root);
        }

        Some(root)
    }

    /// Get peaks of the Merkle Mountain Range, a compact form that the data structure can later
    /// be restored from with [`Self::from_peaks()`]
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn peaks(&self) -> MmrPeaks<MAX_N> {
        let mut result = MmrPeaks {
            num_leaves: self.num_leaves,
            peaks: [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
        };

        // One peak per set bit of `num_leaves`, collected from the lowest level upwards
        let mut stack_bits = self.num_leaves;
        let mut peaks_offset = 0;
        while stack_bits != 0 {
            let stack_offset = stack_bits.trailing_zeros();

            // SAFETY: the number of peaks never exceeds the number of stack slots, and the level
            // is within the stack bounds for any valid `num_leaves`
            *unsafe { result.peaks.get_unchecked_mut(peaks_offset) } =
                *unsafe { self.stack.get_unchecked(stack_offset as usize) };

            peaks_offset += 1;
            stack_bits &= !(1 << stack_offset);
        }

        result
    }

    /// Add a leaf to the Merkle Mountain Range.
    ///
    /// Returns `false` if the leaf does not fit, meaning `MAX_N` leaves were already added.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn add_leaf(&mut self, leaf: &[u8; OUT_LEN]) -> bool {
        if self.num_leaves >= MAX_N {
            return false;
        }

        let mut current = *leaf;

        // Every trailing one-bit of `num_leaves` is a completed subtree that the new leaf merges
        // with, one level at a time
        let lowest_active_levels = self.num_leaves.trailing_ones() as usize;
        for item in self.stack.iter().take(lowest_active_levels) {
            current = hash_pair(item, &current);
        }

        // SAFETY: `num_leaves < MAX_N`, so the target level is within the stack bounds
        *unsafe { self.stack.get_unchecked_mut(lowest_active_levels) } = current;
        self.num_leaves += 1;

        true
    }

    /// Add many leaves to the Merkle Mountain Range.
    ///
    /// Returns `false` if some leaves did not fit, meaning `MAX_N` leaves were reached before the
    /// iterator was exhausted (leaves up to that point are still added).
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn add_leaves<'a, Item, Iter>(&mut self, leaves: Iter) -> bool
    where
        Item: Into<[u8; OUT_LEN]>,
        Iter: IntoIterator<Item = Item> + 'a,
    {
        for leaf in leaves {
            if self.num_leaves >= MAX_N {
                return false;
            }

            let mut current = leaf.into();

            // Merge the new leaf with every completed subtree, see `add_leaf()`
            let lowest_active_levels = self.num_leaves.trailing_ones() as usize;
            for item in self.stack.iter().take(lowest_active_levels) {
                current = hash_pair(item, &current);
            }

            self.stack[lowest_active_levels] = current;
            self.num_leaves += 1;
        }

        true
    }

    /// Add a leaf to the Merkle Mountain Range and compute the new root together with a proof of
    /// inclusion for the just-added leaf.
    ///
    /// Returns `None` if the leaf does not fit, meaning `MAX_N` leaves were already added.
    #[inline]
    #[cfg(feature = "alloc")]
    pub fn add_leaf_and_compute_proof(
        &mut self,
        leaf: &[u8; OUT_LEN],
    ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)> {
        // SAFETY: an array of `MaybeUninit` is always "initialized"
        let mut proof = unsafe {
            Box::<[MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1]>::new_uninit()
                .assume_init()
        };

        let (root, proof_length) = self.add_leaf_and_compute_proof_inner(leaf, &mut proof)?;

        let proof_capacity = proof.len();
        let proof = Box::into_raw(proof);
        // SAFETY: the allocation came from a `Box` with the same layout, the first `proof_length`
        // elements were initialized above and `proof_length <= proof_capacity`
        let proof = unsafe {
            Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
        };

        Some((root, proof))
    }

    /// Add a leaf to the Merkle Mountain Range and compute the new root together with a proof of
    /// inclusion for the just-added leaf, writing the proof into the provided buffer.
    ///
    /// Returns `None` if the leaf does not fit, meaning `MAX_N` leaves were already added.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn add_leaf_and_compute_proof_in<'proof>(
        &mut self,
        leaf: &[u8; OUT_LEN],
        proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
    ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])> {
        let (root, proof_length) = self.add_leaf_and_compute_proof_inner(leaf, proof)?;

        // SAFETY: `proof_length` never exceeds the buffer length, and the first `proof_length`
        // elements were initialized above
        let proof = unsafe {
            proof
                .split_at_mut_unchecked(proof_length)
                .0
                .assume_init_mut()
        };

        Some((root, proof))
    }

    /// Like [`Self::add_leaf_and_compute_proof_in()`], but returns the number of proof elements
    /// written into the buffer instead of an initialized slice
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn add_leaf_and_compute_proof_inner(
        &mut self,
        leaf: &[u8; OUT_LEN],
        proof: &mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
    ) -> Option<([u8; OUT_LEN], usize)> {
        let mut proof_length = 0;

        let current_target_level;
        // Index of the new leaf, halved every time a level is climbed
        let mut position = self.num_leaves;

        {
            if self.num_leaves >= MAX_N {
                return None;
            }

            let mut current = *leaf;

            let lowest_active_levels = self.num_leaves.trailing_ones() as usize;

            for item in self.stack.iter().take(lowest_active_levels) {
                // The completed subtree to the left of the new leaf is a proof element
                // SAFETY: the proof buffer has one slot per level
                unsafe { proof.get_unchecked_mut(proof_length) }.write(*item);
                proof_length += 1;

                current = hash_pair(item, &current);

                position /= 2;
            }

            current_target_level = lowest_active_levels;

            self.stack[lowest_active_levels] = current;
            self.num_leaves += 1;
        }

        // Bag the peaks into the root, collecting the remaining proof elements along the way
        let mut root;
        let mut stack_bits = self.num_leaves;

        {
            let lowest_active_level = stack_bits.trailing_zeros() as usize;
            // SAFETY: the level is within the stack bounds for any valid `num_leaves`
            root = *unsafe { self.stack.get_unchecked(lowest_active_level) };
            stack_bits &= !(1 << lowest_active_level);
        }

        let mut merged_peaks = false;
        loop {
            let lowest_active_level = stack_bits.trailing_zeros() as usize;

            if lowest_active_level == u64::BITS as usize {
                break;
            }

            stack_bits &= !(1 << lowest_active_level);

            // SAFETY: the level is within the stack bounds for any valid `num_leaves`
            let lowest_active_level_item = unsafe { self.stack.get_unchecked(lowest_active_level) };

            if lowest_active_level > current_target_level
                || (lowest_active_level == current_target_level
                    && (position % 2 != 0)
                    && !merged_peaks)
            {
                // SAFETY: the proof buffer has one slot per level
                unsafe { proof.get_unchecked_mut(proof_length) }.write(*lowest_active_level_item);
                proof_length += 1;
                merged_peaks = false;
            } else if lowest_active_level == current_target_level {
                // SAFETY: the proof buffer has one slot per level
                unsafe { proof.get_unchecked_mut(proof_length) }.write(root);
                proof_length += 1;
                merged_peaks = false;
            } else {
                merged_peaks = true;
            }

            root = hash_pair(lowest_active_level_item, &root);

            position /= 2;
        }

        Some((root, proof_length))
    }

    /// Verify a Merkle proof for a leaf previously added to the Merkle Mountain Range.
    ///
    /// This is a thin wrapper around [`UnbalancedMerkleTree::verify()`].
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn verify(
        root: &[u8; OUT_LEN],
        proof: &[[u8; OUT_LEN]],
        leaf_index: u64,
        leaf: [u8; OUT_LEN],
        num_leaves: u64,
    ) -> bool {
        UnbalancedMerkleTree::verify(root, proof, leaf_index, leaf, num_leaves)
    }
}
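
// A minimal usage sketch: build a small Merkle Mountain Range, prove a leaf while adding it,
// verify the proof against the new root, and round-trip through `MmrPeaks`. Illustrative only; it
// assumes the `alloc` feature (for `add_leaf_and_compute_proof()`) and a test environment where
// `std` is available.
#[cfg(all(test, feature = "alloc"))]
mod example_usage {
    use super::*;

    #[test]
    fn add_prove_verify_round_trip() {
        let mut mmr = MerkleMountainRange::<8>::new();
        assert!(mmr.add_leaf(&[1u8; OUT_LEN]));
        assert!(mmr.add_leaf(&[2u8; OUT_LEN]));

        // Prove the third leaf (index 2) while adding it
        let leaf = [3u8; OUT_LEN];
        let (root, proof) = mmr.add_leaf_and_compute_proof(&leaf).unwrap();
        assert_eq!(mmr.root(), Some(root));
        assert!(MerkleMountainRange::<8>::verify(
            &root,
            &proof,
            2,
            leaf,
            mmr.num_leaves(),
        ));

        // Peaks are a compact commitment that restores an equivalent Merkle Mountain Range
        let peaks = mmr.peaks();
        let restored = MerkleMountainRange::<8>::from_peaks(&peaks).unwrap();
        assert_eq!(restored.root(), mmr.root());
    }
}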