ab_merkle_tree/
mmr.rs

1use crate::hash_pair;
2use crate::unbalanced::UnbalancedMerkleTree;
3use ab_blake3::OUT_LEN;
4#[cfg(feature = "alloc")]
5use alloc::boxed::Box;
6#[cfg(feature = "alloc")]
7use alloc::vec::Vec;
8use core::mem;
9use core::mem::MaybeUninit;
10use core::ops::{Deref, DerefMut};
11
/// MMR peaks for [`MerkleMountainRange`].
///
/// Primarily intended to be used with [`MerkleMountainRange::from_peaks()`], can be sent over the
/// network, etc.
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
pub struct MmrPeaks<const MAX_N: u64>
where
    // One slot per possible tree level: `ilog2(MAX_N) + 1` levels suffice for up to `MAX_N` leaves
    [(); MAX_N.ilog2() as usize + 1]:,
{
    /// Number of leaves in MMR
    pub num_leaves: u64,
    /// MMR peaks, first [`Self::num_peaks()`] elements are occupied by values, the rest are ignored
    /// and do not need to be retained.
    pub peaks: [[u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
}
27
28impl<const MAX_N: u64> MmrPeaks<MAX_N>
29where
30    [(); MAX_N.ilog2() as usize + 1]:,
31{
32    /// Number of peaks stored in [`Self::peaks`] that are occupied by actual values
33    #[inline(always)]
34    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
35    pub fn num_peaks(&self) -> u8 {
36        self.num_leaves.count_ones() as u8
37    }
38}
39
/// Byte representation of [`MerkleMountainRange`] with correct alignment.
///
/// Somewhat similar in function to [`MmrPeaks`], but for local use only.
// `align(8)` matches the alignment of `MerkleMountainRange` (which starts with a `u64` field) so
// the two types can be transmuted back and forth; equivalence is checked by `const` assertions
// below
#[derive(Debug, Copy, Clone)]
#[repr(C, align(8))]
pub struct MerkleMountainRangeBytes<const MAX_N: u64>(
    [u8; merkle_mountain_range_bytes_size(MAX_N)],
)
where
    [(); merkle_mountain_range_bytes_size(MAX_N)]:;
50
51impl<const MAX_N: u64> Default for MerkleMountainRangeBytes<MAX_N>
52where
53    [(); merkle_mountain_range_bytes_size(MAX_N)]:,
54{
55    #[inline(always)]
56    fn default() -> Self {
57        Self([0; _])
58    }
59}
60
61impl<const MAX_N: u64> From<[u8; merkle_mountain_range_bytes_size(MAX_N)]>
62    for MerkleMountainRangeBytes<MAX_N>
63where
64    [(); merkle_mountain_range_bytes_size(MAX_N)]:,
65{
66    fn from(value: [u8; merkle_mountain_range_bytes_size(MAX_N)]) -> Self {
67        Self(value)
68    }
69}
70
71impl<const MAX_N: u64> From<MerkleMountainRangeBytes<MAX_N>>
72    for [u8; merkle_mountain_range_bytes_size(MAX_N)]
73where
74    [(); merkle_mountain_range_bytes_size(MAX_N)]:,
75{
76    fn from(value: MerkleMountainRangeBytes<MAX_N>) -> Self {
77        value.0
78    }
79}
80
impl<const MAX_N: u64> Deref for MerkleMountainRangeBytes<MAX_N>
where
    [(); merkle_mountain_range_bytes_size(MAX_N)]:,
{
    /// Raw bytes of the Merkle Mountain Range representation
    type Target = [u8; merkle_mountain_range_bytes_size(MAX_N)];

    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
92
impl<const MAX_N: u64> DerefMut for MerkleMountainRangeBytes<MAX_N>
where
    [(); merkle_mountain_range_bytes_size(MAX_N)]:,
{
    // NOTE(review): mutating the bytes arbitrarily is allowed here; `from_bytes()` remains sound
    // only because it is `unsafe` and requires bytes previously produced by `as_bytes()`
    #[inline(always)]
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
102
103/// Size of [`MerkleMountainRange`]/[`MerkleMountainRangeBytes`] in bytes
104pub const fn merkle_mountain_range_bytes_size(max_n: u64) -> usize {
105    size_of::<u64>() + OUT_LEN * (max_n.ilog2() as usize + 1)
106}
107
// Compile-time proof that the byte representation has exactly the size and alignment relied upon
// by the `transmute`s in `as_bytes()`/`from_bytes()` (checked for a representative `MAX_N = 2`)
const _: () = {
    assert!(size_of::<MerkleMountainRangeBytes<2>>() == merkle_mountain_range_bytes_size(2));
    assert!(size_of::<MerkleMountainRange<2>>() == merkle_mountain_range_bytes_size(2));
    assert!(align_of::<MerkleMountainRangeBytes<2>>() == align_of::<MerkleMountainRange<2>>());
};
113
/// Merkle Mountain Range variant that has pre-hashed leaves with arbitrary number of elements.
///
/// This can be considered a general case of [`UnbalancedMerkleTree`]. The root and proofs are
/// identical for both. [`UnbalancedMerkleTree`] is more efficient and should be preferred when
/// possible, while this data structure is designed for aggregating data incrementally over long
/// periods of time.
///
/// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
/// usage.
#[derive(Debug, Copy, Clone)]
#[repr(C)]
pub struct MerkleMountainRange<const MAX_N: u64>
where
    [(); MAX_N.ilog2() as usize + 1]:,
{
    // Total number of leaves added so far; its binary representation doubles as the occupancy
    // bitmap of `stack` (bit `i` set means level `i` currently holds a peak)
    num_leaves: u64,
    // Stack of intermediate nodes per tree level
    stack: [[u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
}
133
impl<const MAX_N: u64> Default for MerkleMountainRange<MAX_N>
where
    [(); MAX_N.ilog2() as usize + 1]:,
{
    /// Equivalent to [`Self::new()`]
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    fn default() -> Self {
        Self::new()
    }
}
144
145// TODO: Think harder about proof generation and verification API here
146impl<const MAX_N: u64> MerkleMountainRange<MAX_N>
147where
148    [(); MAX_N.ilog2() as usize + 1]:,
149{
    /// Create an empty instance
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn new() -> Self {
        Self {
            num_leaves: 0,
            // Stack contents are only ever read for levels whose bit is set in `num_leaves`, so
            // the initial zeroes are never observed
            stack: [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
        }
    }
159
160    /// Create a new instance from previously collected peaks.
161    ///
162    /// Returns `None` if input is invalid.
163    #[inline]
164    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
165    pub fn from_peaks(peaks: &MmrPeaks<MAX_N>) -> Option<Self> {
166        let mut result = Self {
167            num_leaves: peaks.num_leaves,
168            stack: [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
169        };
170
171        // Convert peaks (where all occupied entries are all at the beginning of the list instead)
172        // to stack (where occupied entries are at corresponding offsets)
173        let mut stack_bits = peaks.num_leaves;
174        let mut peaks_offset = 0;
175
176        while stack_bits != 0 {
177            let stack_offset = stack_bits.trailing_zeros();
178
179            *result.stack.get_mut(stack_offset as usize)? = *peaks.peaks.get(peaks_offset)?;
180
181            peaks_offset += 1;
182            // Clear the lowest set bit
183            stack_bits &= !(1 << stack_offset);
184        }
185
186        Some(result)
187    }
188
    /// Get byte representation of Merkle Mountain Range
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn as_bytes(&self) -> &MerkleMountainRangeBytes<MAX_N>
    where
        [(); merkle_mountain_range_bytes_size(MAX_N)]:,
    {
        // Size/alignment equivalence of the two types is verified by `const` assertions above
        // SAFETY: Both are `#[repr(C)]`, the same size and alignment as `Self`, all bit patterns
        // are valid
        unsafe { mem::transmute(self) }
    }
200
    /// Create an instance from byte representation.
    ///
    /// # Safety
    /// Bytes must be previously created by [`Self::as_bytes()`].
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub unsafe fn from_bytes(bytes: &MerkleMountainRangeBytes<MAX_N>) -> &Self
    where
        [(); merkle_mountain_range_bytes_size(MAX_N)]:,
    {
        // Size/alignment equivalence of the two types is verified by `const` assertions above
        // SAFETY: Both are `#[repr(C)]`, the same size and alignment as `Self`, all bit patterns
        // are valid. `::from_bytes()` is an `unsafe` function with correct invariant being a
        // prerequisite of calling it.
        unsafe { mem::transmute(bytes) }
    }
216
    /// Get number of leaves aggregated in Merkle Mountain Range so far
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn num_leaves(&self) -> u64 {
        // Plain accessor; internally the binary representation of this value also encodes which
        // stack levels are occupied
        self.num_leaves
    }
223
    /// Calculate the root of Merkle Mountain Range.
    ///
    /// In case MMR contains a single leaf hash, that leaf hash is returned, `None` is returned if
    /// there were no leaves added yet.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn root(&self) -> Option<[u8; OUT_LEN]> {
        if self.num_leaves == 0 {
            // If no leaves were provided
            return None;
        }

        // Each set bit of `num_leaves` marks an occupied stack level (a peak); peaks are merged
        // from the lowest level upwards
        let mut root;
        let mut stack_bits = self.num_leaves;
        {
            let lowest_active_level = stack_bits.trailing_zeros() as usize;
            // SAFETY: Active level must have been set successfully before, hence it exists
            root = *unsafe { self.stack.get_unchecked(lowest_active_level) };
            // Clear lowest active level
            stack_bits &= !(1 << lowest_active_level);
        }

        // Hash remaining peaks (if any) of the potentially unbalanced tree together
        loop {
            let lowest_active_level = stack_bits.trailing_zeros() as usize;

            // `trailing_zeros()` returns the full bit width once no bits remain set
            if lowest_active_level == u64::BITS as usize {
                break;
            }

            // Clear lowest active level for next iteration
            stack_bits &= !(1 << lowest_active_level);

            // SAFETY: Active level must have been set successfully before, hence it exists
            let lowest_active_level_item = unsafe { self.stack.get_unchecked(lowest_active_level) };

            // The higher (larger) peak is always the left hashing operand
            root = hash_pair(lowest_active_level_item, &root);
        }

        Some(root)
    }
265
    /// Get peaks of Merkle Mountain Range
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn peaks(&self) -> MmrPeaks<MAX_N> {
        let mut result = MmrPeaks {
            num_leaves: self.num_leaves,
            // Entries past `num_peaks()` stay zeroed and are documented as ignored
            peaks: [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
        };

        // Convert stack (where occupied entries are at corresponding offsets) to peaks (where all
        // occupied entries are all at the beginning of the list instead)
        let mut stack_bits = self.num_leaves;
        let mut peaks_offset = 0;
        while stack_bits != 0 {
            let stack_offset = stack_bits.trailing_zeros();

            // SAFETY: Stack offset is always within the range of stack and peaks, this is
            // guaranteed by internal invariants of the MMR
            *unsafe { result.peaks.get_unchecked_mut(peaks_offset) } =
                *unsafe { self.stack.get_unchecked(stack_offset as usize) };

            peaks_offset += 1;
            // Clear the lowest set bit
            stack_bits &= !(1 << stack_offset);
        }

        result
    }
294
    /// Add leaf to Merkle Mountain Range.
    ///
    /// There is a more efficient version [`Self::add_leaves()`] in case multiple leaves are
    /// available.
    ///
    /// Returns `true` on success, `false` if too many leaves were added.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn add_leaf(&mut self, leaf: &[u8; OUT_LEN]) -> bool {
        // How many leaves were processed so far. Should have been `num_leaves == MAX_N`, but `>=`
        // helps compiler with panic safety checks.
        if self.num_leaves >= MAX_N {
            return false;
        }

        let mut current = *leaf;

        // Every bit set to `1` corresponds to an active Merkle Tree level
        let lowest_active_levels = self.num_leaves.trailing_ones() as usize;
        // Merge the new leaf with each occupied lower level, analogous to carry propagation when
        // incrementing a binary number
        for item in self.stack.iter().take(lowest_active_levels) {
            current = hash_pair(item, &current);
        }

        // Place the current hash at the first inactive level
        // SAFETY: Stack is statically guaranteed to support all active levels with number of leaves
        // checked at the beginning of the function.
        // In fact the same exact code in `add_leaves()` doesn't require unchecked access, but here
        // compiler is somehow unable to prove that panic can't happen otherwise.
        *unsafe { self.stack.get_unchecked_mut(lowest_active_levels) } = current;
        self.num_leaves += 1;

        true
    }
328
    /// Add many leaves to Merkle Mountain Range.
    ///
    /// This is a more efficient version of [`Self::add_leaf()`] in case multiple leaves are
    /// available.
    ///
    /// Returns `true` on success, `false` if too many leaves were added.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn add_leaves<'a, Item, Iter>(&mut self, leaves: Iter) -> bool
    where
        Item: Into<[u8; OUT_LEN]>,
        Iter: IntoIterator<Item = Item> + 'a,
    {
        // TODO: This can be optimized further
        for leaf in leaves {
            // How many leaves were processed so far. Should have been `num_leaves == MAX_N`, but
            // `>=` helps compiler with panic safety checks.
            // NOTE: Leaves consumed before this point are retained even when `false` is returned
            if self.num_leaves >= MAX_N {
                return false;
            }

            let mut current = leaf.into();

            // Every bit set to `1` corresponds to an active Merkle Tree level
            let lowest_active_levels = self.num_leaves.trailing_ones() as usize;
            // Merge the new leaf with each occupied lower level (binary carry propagation)
            for item in self.stack.iter().take(lowest_active_levels) {
                current = hash_pair(item, &current);
            }

            // Place the current hash at the first inactive level
            self.stack[lowest_active_levels] = current;
            self.num_leaves += 1;
        }

        true
    }
365
    /// Add leaf to Merkle Mountain Range and generate inclusion proof.
    ///
    /// Returns `Some((root, proof))` on success, `None` if too many leaves were added.
    #[inline]
    #[cfg(feature = "alloc")]
    pub fn add_leaf_and_compute_proof(
        &mut self,
        leaf: &[u8; OUT_LEN],
    ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)> {
        // Allocate a maximally-sized uninitialized buffer up front; only the first `proof_length`
        // entries will be initialized by the inner call
        // SAFETY: Inner value is `MaybeUninit`
        let mut proof = unsafe {
            Box::<[MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1]>::new_uninit()
                .assume_init()
        };

        let (root, proof_length) = self.add_leaf_and_compute_proof_inner(leaf, &mut proof)?;

        // Reuse the boxed buffer as the `Vec`'s backing allocation instead of copying it
        let proof_capacity = proof.len();
        let proof = Box::into_raw(proof);
        // SAFETY: Points to correctly allocated memory where `proof_length` elements were
        // initialized
        let proof = unsafe {
            Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
        };

        Some((root, proof))
    }
393
    /// Add leaf to Merkle Mountain Range and generate inclusion proof.
    ///
    /// Returns `Some((root, proof))` on success, `None` if too many leaves were added.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn add_leaf_and_compute_proof_in<'proof>(
        &mut self,
        leaf: &[u8; OUT_LEN],
        proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
    ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])> {
        let (root, proof_length) = self.add_leaf_and_compute_proof_inner(leaf, proof)?;

        // Return only the initialized prefix of the caller-provided buffer
        // SAFETY: Just correctly initialized `proof_length` elements
        let proof = unsafe {
            proof
                .split_at_mut_unchecked(proof_length)
                .0
                .assume_init_mut()
        };

        Some((root, proof))
    }
416
    /// Add leaf to Merkle Mountain Range, writing the inclusion proof into a caller-provided
    /// buffer of uninitialized entries.
    ///
    /// Returns `Some((root, proof_length))` on success, where exactly the first `proof_length`
    /// entries of `proof` were initialized, `None` if too many leaves were added.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn add_leaf_and_compute_proof_inner(
        &mut self,
        leaf: &[u8; OUT_LEN],
        proof: &mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
    ) -> Option<([u8; OUT_LEN], usize)> {
        let mut proof_length = 0;

        // The level at which the newly added leaf's subtree ends up stored in the stack
        let current_target_level;
        // Index of the new leaf's node within its current level while walking up the tree
        let mut position = self.num_leaves;

        {
            // How many leaves were processed so far. Should have been `num_leaves == MAX_N`, but
            // `>=` helps compiler with panic safety checks.
            if self.num_leaves >= MAX_N {
                return None;
            }

            let mut current = *leaf;

            // Every bit set to `1` corresponds to an active Merkle Tree level
            let lowest_active_levels = self.num_leaves.trailing_ones() as usize;

            for item in self.stack.iter().take(lowest_active_levels) {
                // If at the target leaf index, need to collect the proof
                // SAFETY: Method signature guarantees upper bound of the proof length
                unsafe { proof.get_unchecked_mut(proof_length) }.write(*item);
                proof_length += 1;

                current = hash_pair(item, &current);

                // Move up the tree
                position /= 2;
            }

            current_target_level = lowest_active_levels;

            // Place the current hash at the first inactive level
            self.stack[lowest_active_levels] = current;
            self.num_leaves += 1;
        }

        // Fold all remaining peaks into the root (same traversal as `root()`), collecting proof
        // hashes along the way
        let mut root;
        let mut stack_bits = self.num_leaves;

        {
            let lowest_active_level = stack_bits.trailing_zeros() as usize;
            // SAFETY: Active level must have been set successfully before, hence it exists
            root = *unsafe { self.stack.get_unchecked(lowest_active_level) };
            // Clear lowest active level
            stack_bits &= !(1 << lowest_active_level);
        }

        // Hash remaining peaks (if any) of the potentially unbalanced tree together and collect
        // proof hashes
        let mut merged_peaks = false;
        loop {
            let lowest_active_level = stack_bits.trailing_zeros() as usize;

            // `trailing_zeros()` returns the full bit width once no bits remain set
            if lowest_active_level == u64::BITS as usize {
                break;
            }

            // Clear lowest active level for next iteration
            stack_bits &= !(1 << lowest_active_level);

            // SAFETY: Active level must have been set successfully before, hence it exists
            let lowest_active_level_item = unsafe { self.stack.get_unchecked(lowest_active_level) };

            if lowest_active_level > current_target_level
                || (lowest_active_level == current_target_level
                    && !position.is_multiple_of(2)
                    && !merged_peaks)
            {
                // SAFETY: Method signature guarantees upper bound of the proof length
                unsafe { proof.get_unchecked_mut(proof_length) }.write(*lowest_active_level_item);
                proof_length += 1;
                merged_peaks = false;
            } else if lowest_active_level == current_target_level {
                // SAFETY: Method signature guarantees upper bound of the proof length
                unsafe { proof.get_unchecked_mut(proof_length) }.write(root);
                proof_length += 1;
                merged_peaks = false;
            } else {
                // Not collecting proof because of the need to merge peaks of an unbalanced tree
                merged_peaks = true;
            }

            // Collect the lowest peak into the proof
            root = hash_pair(lowest_active_level_item, &root);

            position /= 2;
        }

        Some((root, proof_length))
    }
514
    /// Verify a Merkle proof for a leaf at the given index.
    ///
    /// NOTE: `MAX_N` constant doesn't matter here and can be anything that is `>= 1`.
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn verify(
        root: &[u8; OUT_LEN],
        proof: &[[u8; OUT_LEN]],
        leaf_index: u64,
        leaf: [u8; OUT_LEN],
        num_leaves: u64,
    ) -> bool {
        // MMR roots and proofs are identical to those of `UnbalancedMerkleTree` (see type-level
        // docs), so verification is simply delegated
        UnbalancedMerkleTree::verify(root, proof, leaf_index, leaf, num_leaves)
    }
529}