ab_merkle_tree/
unbalanced.rs

1use crate::hash_pair;
2use ab_blake3::OUT_LEN;
3#[cfg(feature = "alloc")]
4use alloc::boxed::Box;
5#[cfg(feature = "alloc")]
6use alloc::vec::Vec;
7use core::mem::MaybeUninit;
8
/// Merkle Tree over pre-hashed leaves that supports an arbitrary (not necessarily power-of-two)
/// number of leaves.
///
/// [`BalancedMerkleTree`] is the special case of this tree: whenever the number of leaves is a
/// power of two, both produce identical roots and proofs. [`BalancedMerkleTree`] is more
/// efficient, so prefer it whenever the leaf count allows.
///
/// [`BalancedMerkleTree`]: crate::balanced::BalancedMerkleTree
///
/// No padding is applied. Instead, the tree is built the way a Merkle Mountain Range is: leaves
/// are grouped into perfect subtrees ("peaks") of decreasing size, and the peaks are then hashed
/// together starting from the smallest (rightmost) one. With seven leaves this gives:
/// ```text
///               Root
///         /--------------\
///        H3              H4
///    /-------\         /----\
///   H0       H1       H2     \
///  /  \     /  \     /  \     \
/// L0  L1   L2  L3   L4  L5    L6
/// ```
#[derive(Debug)]
pub struct UnbalancedMerkleTree;
29
30// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
31//  https://github.com/BLAKE3-team/BLAKE3/issues/478
32// TODO: Experiment with replacing a single pass with splitting the whole data set with a sequence
33//  of power-of-two elements that can be processed in parallel and do it recursively until a single
34//  element is left. This can be done for both root creation and proof generation.
35impl UnbalancedMerkleTree {
36    /// Compute Merkle Tree Root.
37    ///
38    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
39    /// usage.
40    ///
41    /// Returns `None` for an empty list of leaves.
42    #[inline]
43    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
44    pub fn compute_root_only<'a, const MAX_N: u64, Item, Iter>(
45        leaves: Iter,
46    ) -> Option<[u8; OUT_LEN]>
47    where
48        [(); MAX_N.ilog2() as usize + 1]:,
49        Item: Into<[u8; OUT_LEN]>,
50        Iter: IntoIterator<Item = Item> + 'a,
51    {
52        // Stack of intermediate nodes per tree level
53        let mut stack = [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1];
54        let mut num_leaves = 0_u64;
55
56        for hash in leaves {
57            // How many leaves were processed so far
58            if num_leaves >= MAX_N {
59                return None;
60            }
61
62            let mut current = hash.into();
63
64            // Every bit set to `1` corresponds to an active Merkle Tree level
65            let lowest_active_levels = num_leaves.trailing_ones() as usize;
66            for item in stack.iter().take(lowest_active_levels) {
67                current = hash_pair(item, &current);
68            }
69
70            // Place the current hash at the first inactive level
71            stack[lowest_active_levels] = current;
72            num_leaves += 1;
73        }
74
75        if num_leaves == 0 {
76            // If no leaves were provided
77            return None;
78        }
79
80        let mut stack_bits = num_leaves;
81
82        {
83            let lowest_active_level = stack_bits.trailing_zeros() as usize;
84            // Reuse `stack[0]` for resulting value
85            // SAFETY: Active level must have been set successfully before, hence it exists
86            stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
87            // Clear lowest active level
88            stack_bits &= !(1 << lowest_active_level);
89        }
90
91        // Hash remaining peaks (if any) of the potentially unbalanced tree together
92        loop {
93            let lowest_active_level = stack_bits.trailing_zeros() as usize;
94
95            if lowest_active_level == u64::BITS as usize {
96                break;
97            }
98
99            // Clear lowest active level for next iteration
100            stack_bits &= !(1 << lowest_active_level);
101
102            // SAFETY: Active level must have been set successfully before, hence it exists
103            let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };
104
105            stack[0] = hash_pair(lowest_active_level_item, &stack[0]);
106        }
107
108        Some(stack[0])
109    }
110
111    /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
112    ///
113    /// Returns `Some(root, proof)` on success, `None` if index is outside of list of leaves.
114    ///
115    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
116    /// usage.
117    #[inline]
118    #[cfg(feature = "alloc")]
119    pub fn compute_root_and_proof<'a, const MAX_N: u64, Item, Iter>(
120        leaves: Iter,
121        target_index: usize,
122    ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)>
123    where
124        [(); MAX_N.ilog2() as usize + 1]:,
125        Item: Into<[u8; OUT_LEN]>,
126        Iter: IntoIterator<Item = Item> + 'a,
127    {
128        // Stack of intermediate nodes per tree level
129        let mut stack = [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1];
130        // SAFETY: Inner value is `MaybeUninit`
131        let mut proof = unsafe {
132            Box::<[MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1]>::new_uninit()
133                .assume_init()
134        };
135
136        let (root, proof_length) =
137            Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, &mut proof)?;
138
139        let proof_capacity = proof.len();
140        let proof = Box::into_raw(proof);
141        // SAFETY: Points to correctly allocated memory where `proof_length` elements were
142        // initialized
143        let proof = unsafe {
144            Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
145        };
146
147        Some((root, proof))
148    }
149
150    /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
151    ///
152    /// Returns `Some(root, proof)` on success, `None` if index is outside of list of leaves.
153    ///
154    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
155    /// usage.
156    #[inline]
157    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
158    pub fn compute_root_and_proof_in<'a, 'proof, const MAX_N: u64, Item, Iter>(
159        leaves: Iter,
160        target_index: usize,
161        proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
162    ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])>
163    where
164        [(); MAX_N.ilog2() as usize + 1]:,
165        Item: Into<[u8; OUT_LEN]>,
166        Iter: IntoIterator<Item = Item> + 'a,
167    {
168        // Stack of intermediate nodes per tree level
169        let mut stack = [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1];
170
171        let (root, proof_length) =
172            Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, proof)?;
173        // SAFETY: Just correctly initialized `proof_length` elements
174        let proof = unsafe {
175            proof
176                .split_at_mut_unchecked(proof_length)
177                .0
178                .assume_init_mut()
179        };
180
181        Some((root, proof))
182    }
183
184    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
185    fn compute_root_and_proof_inner<'a, const MAX_N: u64, Item, Iter>(
186        leaves: Iter,
187        target_index: usize,
188        stack: &mut [[u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
189        proof: &mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
190    ) -> Option<([u8; OUT_LEN], usize)>
191    where
192        [(); MAX_N.ilog2() as usize + 1]:,
193        Item: Into<[u8; OUT_LEN]>,
194        Iter: IntoIterator<Item = Item> + 'a,
195    {
196        let mut proof_length = 0;
197        let mut num_leaves = 0_u64;
198
199        let mut current_target_level = None;
200        let mut position = target_index;
201
202        for (current_index, hash) in leaves.into_iter().enumerate() {
203            // How many leaves were processed so far
204            if num_leaves >= MAX_N {
205                return None;
206            }
207
208            let mut current = hash.into();
209
210            // Every bit set to `1` corresponds to an active Merkle Tree level
211            let lowest_active_levels = num_leaves.trailing_ones() as usize;
212
213            if current_index == target_index {
214                for item in stack.iter().take(lowest_active_levels) {
215                    // If at the target leaf index, need to collect the proof
216                    // SAFETY: Method signature guarantees upper bound of the proof length
217                    unsafe { proof.get_unchecked_mut(proof_length) }.write(*item);
218                    proof_length += 1;
219
220                    current = hash_pair(item, &current);
221
222                    // Move up the tree
223                    position /= 2;
224                }
225
226                current_target_level = Some(lowest_active_levels);
227            } else {
228                for (level, item) in stack.iter().enumerate().take(lowest_active_levels) {
229                    if current_target_level == Some(level) {
230                        // SAFETY: Method signature guarantees upper bound of the proof length
231                        unsafe { proof.get_unchecked_mut(proof_length) }
232                            .write(if position % 2 == 0 { current } else { *item });
233                        proof_length += 1;
234
235                        current_target_level = Some(level + 1);
236
237                        // Move up the tree
238                        position /= 2;
239                    }
240
241                    current = hash_pair(item, &current);
242                }
243            }
244
245            // Place the current hash at the first inactive level
246            stack[lowest_active_levels] = current;
247            num_leaves += 1;
248        }
249
250        // `active_levels` here contains the number of leaves after above loop
251        if target_index >= num_leaves as usize {
252            // If no leaves were provided
253            return None;
254        }
255
256        let Some(current_target_level) = current_target_level else {
257            // Index not found
258            return None;
259        };
260
261        let mut stack_bits = num_leaves;
262
263        {
264            let lowest_active_level = stack_bits.trailing_zeros() as usize;
265            // Reuse `stack[0]` for resulting value
266            // SAFETY: Active level must have been set successfully before, hence it exists
267            stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
268            // Clear lowest active level
269            stack_bits &= !(1 << lowest_active_level);
270        }
271
272        // Hash remaining peaks (if any) of the potentially unbalanced tree together and collect
273        // proof hashes
274        let mut merged_peaks = false;
275        loop {
276            let lowest_active_level = stack_bits.trailing_zeros() as usize;
277
278            if lowest_active_level == u64::BITS as usize {
279                break;
280            }
281
282            // Clear lowest active level for next iteration
283            stack_bits &= !(1 << lowest_active_level);
284
285            // SAFETY: Active level must have been set successfully before, hence it exists
286            let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };
287
288            if lowest_active_level > current_target_level
289                || (lowest_active_level == current_target_level
290                    && (position % 2 != 0)
291                    && !merged_peaks)
292            {
293                // SAFETY: Method signature guarantees upper bound of the proof length
294                unsafe { proof.get_unchecked_mut(proof_length) }.write(*lowest_active_level_item);
295                proof_length += 1;
296                merged_peaks = false;
297            } else if lowest_active_level == current_target_level {
298                // SAFETY: Method signature guarantees upper bound of the proof length
299                unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[0]);
300                proof_length += 1;
301                merged_peaks = false;
302            } else {
303                // Not collecting proof because of the need to merge peaks of an unbalanced tree
304                merged_peaks = true;
305            }
306
307            // Collect the lowest peak into the proof
308            stack[0] = hash_pair(lowest_active_level_item, &stack[0]);
309
310            position /= 2;
311        }
312
313        Some((stack[0], proof_length))
314    }
315
316    /// Verify a Merkle proof for a leaf at the given index
317    #[inline]
318    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
319    pub fn verify(
320        root: &[u8; OUT_LEN],
321        proof: &[[u8; OUT_LEN]],
322        leaf_index: u64,
323        leaf: [u8; OUT_LEN],
324        num_leaves: u64,
325    ) -> bool {
326        if leaf_index >= num_leaves {
327            return false;
328        }
329
330        let mut current = leaf;
331        let mut position = leaf_index;
332        let mut proof_pos = 0;
333        let mut level_size = num_leaves;
334
335        // Rebuild the path to the root
336        while level_size > 1 {
337            let is_left = position % 2 == 0;
338            let is_last = position == level_size - 1;
339
340            if is_left && !is_last {
341                // Left node with a right sibling
342                if proof_pos >= proof.len() {
343                    // Missing sibling
344                    return false;
345                }
346                current = hash_pair(&current, &proof[proof_pos]);
347                proof_pos += 1;
348            } else if !is_left {
349                // Right node with a left sibling
350                if proof_pos >= proof.len() {
351                    // Missing sibling
352                    return false;
353                }
354                current = hash_pair(&proof[proof_pos], &current);
355                proof_pos += 1;
356            } else {
357                // Last node, no sibling, keep current
358            }
359
360            position /= 2;
361            // Size of next level
362            level_size = level_size.div_ceil(2);
363        }
364
365        // Check if proof is fully used and matches root
366        proof_pos == proof.len() && current == *root
367    }
368}