ab_merkle_tree/
mmr.rs

1use crate::hash_pair;
2use crate::unbalanced::UnbalancedMerkleTree;
3use ab_blake3::OUT_LEN;
4#[cfg(feature = "alloc")]
5use alloc::boxed::Box;
6#[cfg(feature = "alloc")]
7use alloc::vec::Vec;
8use core::mem;
9use core::mem::MaybeUninit;
10
11/// MMR peaks for [`MerkleMountainRange`].
12///
13/// Primarily intended to be used with [`MerkleMountainRange::from_peaks()`], can be sent over the
14/// network, etc.
15#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
16pub struct MmrPeaks<const MAX_N: u64>
17where
18    [(); MAX_N.ilog2() as usize + 1]:,
19{
20    /// Number of leaves in MMR
21    pub num_leaves: u64,
22    /// MMR peaks, first [`Self::num_peaks()`] elements are occupied by values, the rest are ignored
23    /// and do not need to be retained.
24    pub peaks: [[u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
25}
26
27impl<const MAX_N: u64> MmrPeaks<MAX_N>
28where
29    [(); MAX_N.ilog2() as usize + 1]:,
30{
31    /// Number of peaks stored in [`Self::peaks`] that are occupied by actual values
32    #[inline(always)]
33    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
34    pub fn num_peaks(&self) -> u8 {
35        self.num_leaves.count_ones() as u8
36    }
37}
38
39/// Byte representation of [`MerkleMountainRange`] with correct alignment.
40///
41/// Somewhat similar in function to [`MmrPeaks`], but for local use only.
42#[derive(Debug, Copy, Clone)]
43#[repr(C, align(8))]
44pub struct MerkleMountainRangeBytes<const MAX_N: u64>(
45    pub [u8; merkle_mountain_range_bytes_size(MAX_N)],
46)
47where
48    [(); merkle_mountain_range_bytes_size(MAX_N)]:;
49
50/// Size of [`MerkleMountainRange`]/[`MerkleMountainRangeBytes`] in bytes
51pub const fn merkle_mountain_range_bytes_size(max_n: u64) -> usize {
52    size_of::<u64>() + OUT_LEN * (max_n.ilog2() as usize + 1)
53}
54
55const _: () = {
56    assert!(size_of::<MerkleMountainRangeBytes<2>>() == merkle_mountain_range_bytes_size(2));
57    assert!(size_of::<MerkleMountainRange<2>>() == merkle_mountain_range_bytes_size(2));
58    assert!(align_of::<MerkleMountainRangeBytes<2>>() == align_of::<MerkleMountainRange<2>>());
59};
60
61/// Merkle Mountain Range variant that has pre-hashed leaves with arbitrary number of elements.
62///
63/// This can be considered a general case of [`UnbalancedMerkleTree`]. The root and proofs are
64/// identical for both. [`UnbalancedMerkleTree`] is more efficient and should be preferred when
65/// possible, while this data structure is designed for aggregating data incrementally over long
66/// periods of time.
67///
68/// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
69/// usage.
70#[derive(Debug, Copy, Clone)]
71#[repr(C)]
72pub struct MerkleMountainRange<const MAX_N: u64>
73where
74    [(); MAX_N.ilog2() as usize + 1]:,
75{
76    num_leaves: u64,
77    // Stack of intermediate nodes per tree level
78    stack: [[u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
79}
80
81impl<const MAX_N: u64> Default for MerkleMountainRange<MAX_N>
82where
83    [(); MAX_N.ilog2() as usize + 1]:,
84{
85    #[inline(always)]
86    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92// TODO: Think harder about proof generation and verification API here
93impl<const MAX_N: u64> MerkleMountainRange<MAX_N>
94where
95    [(); MAX_N.ilog2() as usize + 1]:,
96{
97    /// Create an empty instance
98    #[inline(always)]
99    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
100    pub fn new() -> Self {
101        Self {
102            num_leaves: 0,
103            stack: [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
104        }
105    }
106
107    /// Create a new instance from previously collected peaks.
108    ///
109    /// Returns `None` if input is invalid.
110    #[inline]
111    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
112    pub fn from_peaks(peaks: &MmrPeaks<MAX_N>) -> Option<Self> {
113        let mut result = Self {
114            num_leaves: peaks.num_leaves,
115            stack: [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
116        };
117
118        // Convert peaks (where all occupied entries are all at the beginning of the list instead)
119        // to stack (where occupied entries are at corresponding offsets)
120        let mut stack_bits = peaks.num_leaves;
121        let mut peaks_offset = 0;
122
123        while stack_bits != 0 {
124            let stack_offset = stack_bits.trailing_zeros();
125
126            *result.stack.get_mut(stack_offset as usize)? = *peaks.peaks.get(peaks_offset)?;
127
128            peaks_offset += 1;
129            // Clear the lowest set bit
130            stack_bits &= !(1 << stack_offset);
131        }
132
133        Some(result)
134    }
135
136    /// Get byte representation of Merkle Mountain Range
137    #[inline(always)]
138    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
139    pub fn as_bytes(&self) -> &MerkleMountainRangeBytes<MAX_N>
140    where
141        [(); merkle_mountain_range_bytes_size(MAX_N)]:,
142    {
143        // SAFETY: Both are `#[repr(C)]`, the same size and alignment as `Self`, all bit patterns
144        // are valid
145        unsafe { mem::transmute(self) }
146    }
147
148    /// Create an instance from byte representation.
149    ///
150    /// # Safety
151    /// Bytes must be previously created by [`Self::as_bytes()`].
152    #[inline(always)]
153    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
154    pub unsafe fn from_bytes(bytes: &MerkleMountainRangeBytes<MAX_N>) -> &Self
155    where
156        [(); merkle_mountain_range_bytes_size(MAX_N)]:,
157    {
158        // SAFETY: Both are `#[repr(C)]`, the same size and alignment as `Self`, all bit patterns
159        // are valid. `::from_bytes()` is an `unsafe` function with correct invariant being a
160        // prerequisite of calling it.
161        unsafe { mem::transmute(bytes) }
162    }
163
164    /// Get number of leaves aggregated in Merkle Mountain Range so far
165    #[inline(always)]
166    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
167    pub fn num_leaves(&self) -> u64 {
168        self.num_leaves
169    }
170
171    /// Calculate the root of Merkle Mountain Range.
172    ///
173    /// In case MMR contains a single leaf hash, that leaf hash is returned, `None` is returned if
174    /// there were no leafs added yet.
175    #[inline]
176    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
177    pub fn root(&self) -> Option<[u8; OUT_LEN]> {
178        if self.num_leaves == 0 {
179            // If no leaves were provided
180            return None;
181        }
182
183        let mut root;
184        let mut stack_bits = self.num_leaves;
185        {
186            let lowest_active_level = stack_bits.trailing_zeros() as usize;
187            // SAFETY: Active level must have been set successfully before, hence it exists
188            root = *unsafe { self.stack.get_unchecked(lowest_active_level) };
189            // Clear lowest active level
190            stack_bits &= !(1 << lowest_active_level);
191        }
192
193        // Hash remaining peaks (if any) of the potentially unbalanced tree together
194        loop {
195            let lowest_active_level = stack_bits.trailing_zeros() as usize;
196
197            if lowest_active_level == u64::BITS as usize {
198                break;
199            }
200
201            // Clear lowest active level for next iteration
202            stack_bits &= !(1 << lowest_active_level);
203
204            // SAFETY: Active level must have been set successfully before, hence it exists
205            let lowest_active_level_item = unsafe { self.stack.get_unchecked(lowest_active_level) };
206
207            root = hash_pair(lowest_active_level_item, &root);
208        }
209
210        Some(root)
211    }
212
213    /// Get peaks of Merkle Mountain Range
214    #[inline]
215    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
216    pub fn peaks(&self) -> MmrPeaks<MAX_N> {
217        let mut result = MmrPeaks {
218            num_leaves: self.num_leaves,
219            peaks: [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
220        };
221
222        // Convert stack (where occupied entries are at corresponding offsets) to peaks (where all
223        // occupied entries are all at the beginning of the list instead)
224        let mut stack_bits = self.num_leaves;
225        let mut peaks_offset = 0;
226        while stack_bits != 0 {
227            let stack_offset = stack_bits.trailing_zeros();
228
229            // SAFETY: Stack offset is always within the range of stack and peaks, this is
230            // guaranteed by internal invariants of the MMR
231            *unsafe { result.peaks.get_unchecked_mut(peaks_offset) } =
232                *unsafe { self.stack.get_unchecked(stack_offset as usize) };
233
234            peaks_offset += 1;
235            // Clear the lowest set bit
236            stack_bits &= !(1 << stack_offset);
237        }
238
239        result
240    }
241
242    /// Add leaf to Merkle Mountain Range.
243    ///
244    /// There is a more efficient version [`Self::add_leaves()`] in case multiple leaves are
245    /// available.
246    ///
247    /// Returns `true` on success, `false` if too many leafs were added.
248    #[inline]
249    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
250    pub fn add_leaf(&mut self, leaf: &[u8; OUT_LEN]) -> bool {
251        // How many leaves were processed so far
252        if self.num_leaves >= MAX_N {
253            return false;
254        }
255
256        let mut current = *leaf;
257
258        // Every bit set to `1` corresponds to an active Merkle Tree level
259        let lowest_active_levels = self.num_leaves.trailing_ones() as usize;
260        for item in self.stack.iter().take(lowest_active_levels) {
261            current = hash_pair(item, &current);
262        }
263
264        // Place the current hash at the first inactive level
265        // SAFETY: Stack is statically guaranteed to support all active levels with number of leaves
266        // checked at the beginning of the function.
267        // In fact the same exact code in `add_leaves()` doesn't require unchecked access, but here
268        // compiler is somehow unable to prove that panic can't happen otherwise.
269        *unsafe { self.stack.get_unchecked_mut(lowest_active_levels) } = current;
270        self.num_leaves += 1;
271
272        true
273    }
274
275    /// Add many leaves to Merkle Mountain Range.
276    ///
277    /// This is a more efficient version of [`Self::add_leaf()`] in case multiple leaves are
278    /// available.
279    ///
280    /// Returns `true` on success, `false` if too many leaves were added.
281    #[inline]
282    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
283    pub fn add_leaves<'a, Item, Iter>(&mut self, leaves: Iter) -> bool
284    where
285        Item: Into<[u8; OUT_LEN]>,
286        Iter: IntoIterator<Item = Item> + 'a,
287    {
288        // TODO: This can be optimized further
289        for leaf in leaves {
290            // How many leaves were processed so far
291            if self.num_leaves >= MAX_N {
292                return false;
293            }
294
295            let mut current = leaf.into();
296
297            // Every bit set to `1` corresponds to an active Merkle Tree level
298            let lowest_active_levels = self.num_leaves.trailing_ones() as usize;
299            for item in self.stack.iter().take(lowest_active_levels) {
300                current = hash_pair(item, &current);
301            }
302
303            // Place the current hash at the first inactive level
304            self.stack[lowest_active_levels] = current;
305            self.num_leaves += 1;
306        }
307
308        true
309    }
310
311    /// Add leaf to Merkle Mountain Range and generate inclusion proof.
312    ///
313    /// Returns `Some((root, proof))` on success, `None` if too many leafs were added.
314    #[inline]
315    #[cfg(feature = "alloc")]
316    pub fn add_leaf_and_compute_proof(
317        &mut self,
318        leaf: &[u8; OUT_LEN],
319    ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)> {
320        // SAFETY: Inner value is `MaybeUninit`
321        let mut proof = unsafe {
322            Box::<[MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1]>::new_uninit()
323                .assume_init()
324        };
325
326        let (root, proof_length) = self.add_leaf_and_compute_proof_inner(leaf, &mut proof)?;
327
328        let proof_capacity = proof.len();
329        let proof = Box::into_raw(proof);
330        // SAFETY: Points to correctly allocated memory where `proof_length` elements were
331        // initialized
332        let proof = unsafe {
333            Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
334        };
335
336        Some((root, proof))
337    }
338
339    /// Add leaf to Merkle Mountain Range and generate inclusion proof.
340    ///
341    /// Returns `Some((root, proof))` on success, `None` if too many leafs were added.
342    #[inline]
343    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
344    pub fn add_leaf_and_compute_proof_in<'proof>(
345        &mut self,
346        leaf: &[u8; OUT_LEN],
347        proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
348    ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])> {
349        let (root, proof_length) = self.add_leaf_and_compute_proof_inner(leaf, proof)?;
350
351        // SAFETY: Just correctly initialized `proof_length` elements
352        let proof = unsafe {
353            proof
354                .split_at_mut_unchecked(proof_length)
355                .0
356                .assume_init_mut()
357        };
358
359        Some((root, proof))
360    }
361
362    #[inline]
363    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
364    pub fn add_leaf_and_compute_proof_inner(
365        &mut self,
366        leaf: &[u8; OUT_LEN],
367        proof: &mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
368    ) -> Option<([u8; OUT_LEN], usize)> {
369        let mut proof_length = 0;
370
371        let current_target_level;
372        let mut position = self.num_leaves;
373
374        {
375            // How many leaves were processed so far
376            if self.num_leaves >= MAX_N {
377                return None;
378            }
379
380            let mut current = *leaf;
381
382            // Every bit set to `1` corresponds to an active Merkle Tree level
383            let lowest_active_levels = self.num_leaves.trailing_ones() as usize;
384
385            for item in self.stack.iter().take(lowest_active_levels) {
386                // If at the target leaf index, need to collect the proof
387                // SAFETY: Method signature guarantees upper bound of the proof length
388                unsafe { proof.get_unchecked_mut(proof_length) }.write(*item);
389                proof_length += 1;
390
391                current = hash_pair(item, &current);
392
393                // Move up the tree
394                position /= 2;
395            }
396
397            current_target_level = lowest_active_levels;
398
399            // Place the current hash at the first inactive level
400            self.stack[lowest_active_levels] = current;
401            self.num_leaves += 1;
402        }
403
404        let mut root;
405        let mut stack_bits = self.num_leaves;
406
407        {
408            let lowest_active_level = stack_bits.trailing_zeros() as usize;
409            // SAFETY: Active level must have been set successfully before, hence it exists
410            root = *unsafe { self.stack.get_unchecked(lowest_active_level) };
411            // Clear lowest active level
412            stack_bits &= !(1 << lowest_active_level);
413        }
414
415        // Hash remaining peaks (if any) of the potentially unbalanced tree together and collect
416        // proof hashes
417        let mut merged_peaks = false;
418        loop {
419            let lowest_active_level = stack_bits.trailing_zeros() as usize;
420
421            if lowest_active_level == u64::BITS as usize {
422                break;
423            }
424
425            // Clear lowest active level for next iteration
426            stack_bits &= !(1 << lowest_active_level);
427
428            // SAFETY: Active level must have been set successfully before, hence it exists
429            let lowest_active_level_item = unsafe { self.stack.get_unchecked(lowest_active_level) };
430
431            if lowest_active_level > current_target_level
432                || (lowest_active_level == current_target_level
433                    && (position % 2 != 0)
434                    && !merged_peaks)
435            {
436                // SAFETY: Method signature guarantees upper bound of the proof length
437                unsafe { proof.get_unchecked_mut(proof_length) }.write(*lowest_active_level_item);
438                proof_length += 1;
439                merged_peaks = false;
440            } else if lowest_active_level == current_target_level {
441                // SAFETY: Method signature guarantees upper bound of the proof length
442                unsafe { proof.get_unchecked_mut(proof_length) }.write(root);
443                proof_length += 1;
444                merged_peaks = false;
445            } else {
446                // Not collecting proof because of the need to merge peaks of an unbalanced tree
447                merged_peaks = true;
448            }
449
450            // Collect the lowest peak into the proof
451            root = hash_pair(lowest_active_level_item, &root);
452
453            position /= 2;
454        }
455
456        Some((root, proof_length))
457    }
458
459    /// Verify a Merkle proof for a leaf at the given index.
460    ///
461    /// NOTE: `MAX_N` constant doesn't matter here and can be anything that is `>= 1`.
462    #[inline]
463    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
464    pub fn verify(
465        root: &[u8; OUT_LEN],
466        proof: &[[u8; OUT_LEN]],
467        leaf_index: u64,
468        leaf: [u8; OUT_LEN],
469        num_leaves: u64,
470    ) -> bool {
471        UnbalancedMerkleTree::verify(root, proof, leaf_index, leaf, num_leaves)
472    }
473}