ab_merkle_tree/
unbalanced.rs

1use crate::hash_pair;
2use ab_blake3::OUT_LEN;
3#[cfg(feature = "alloc")]
4use alloc::boxed::Box;
5#[cfg(feature = "alloc")]
6use alloc::vec::Vec;
7use core::mem::MaybeUninit;
8
/// Merkle Tree over an arbitrary (not necessarily power-of-two) number of pre-hashed leaves.
///
/// This is the general case of [`BalancedMerkleTree`]: when the number of leaves is a power of
/// two, both produce identical roots and proofs. Prefer [`BalancedMerkleTree`] whenever it is
/// applicable since it is more efficient.
///
/// [`BalancedMerkleTree`]: crate::balanced::BalancedMerkleTree
///
/// No padding leaves are added. Instead, the tree is built the way a Merkle Mountain Range would
/// be, with the trailing smaller subtrees merged into the root from right to left:
/// ```text
///               Root
///         /--------------\
///        H3              H4
///    /-------\         /----\
///   H0       H1       H2     \
///  /  \     /  \     /  \     \
/// L0  L1   L2  L3   L4  L5    L6
/// ```
///
/// The type carries no state, all functionality is exposed through associated functions.
#[derive(Debug)]
pub struct UnbalancedMerkleTree;
29
30// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
31//  https://github.com/BLAKE3-team/BLAKE3/issues/478
32// TODO: Experiment with replacing a single pass with splitting the whole data set with a sequence
33//  of power-of-two elements that can be processed in parallel and do it recursively until a single
34//  element is left. This can be done for both root creation and proof generation.
35impl UnbalancedMerkleTree {
36    /// Compute Merkle Tree Root.
37    ///
38    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
39    /// usage.
40    ///
41    /// Returns `None` for an empty list of leaves.
42    #[inline]
43    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
44    pub fn compute_root_only<'a, const MAX_N: u64, Item, Iter>(
45        leaves: Iter,
46    ) -> Option<[u8; OUT_LEN]>
47    where
48        [(); MAX_N.ilog2() as usize + 1]:,
49        Item: Into<[u8; OUT_LEN]>,
50        Iter: IntoIterator<Item = Item> + 'a,
51    {
52        // Stack of intermediate nodes per tree level
53        let mut stack = [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1];
54        let mut num_leaves = 0_u64;
55
56        for hash in leaves {
57            // How many leaves were processed so far. Should have been `num_leaves == MAX_N`, but
58            // `>=` helps compiler with panic safety checks.
59            if num_leaves >= MAX_N {
60                return None;
61            }
62
63            let mut current = hash.into();
64
65            // Every bit set to `1` corresponds to an active Merkle Tree level
66            let lowest_active_levels = num_leaves.trailing_ones() as usize;
67            for item in stack.iter().take(lowest_active_levels) {
68                current = hash_pair(item, &current);
69            }
70
71            // Place the current hash at the first inactive level
72            stack[lowest_active_levels] = current;
73            num_leaves += 1;
74        }
75
76        if num_leaves == 0 {
77            // If no leaves were provided
78            return None;
79        }
80
81        let mut stack_bits = num_leaves;
82
83        {
84            let lowest_active_level = stack_bits.trailing_zeros() as usize;
85            // Reuse `stack[0]` for resulting value
86            // SAFETY: Active level must have been set successfully before, hence it exists
87            stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
88            // Clear lowest active level
89            stack_bits &= !(1 << lowest_active_level);
90        }
91
92        // Hash remaining peaks (if any) of the potentially unbalanced tree together
93        loop {
94            let lowest_active_level = stack_bits.trailing_zeros() as usize;
95
96            if lowest_active_level == u64::BITS as usize {
97                break;
98            }
99
100            // Clear lowest active level for next iteration
101            stack_bits &= !(1 << lowest_active_level);
102
103            // SAFETY: Active level must have been set successfully before, hence it exists
104            let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };
105
106            stack[0] = hash_pair(lowest_active_level_item, &stack[0]);
107        }
108
109        Some(stack[0])
110    }
111
112    /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
113    ///
114    /// Returns `Some(root, proof)` on success, `None` if index is outside of list of leaves.
115    ///
116    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
117    /// usage.
118    #[inline]
119    #[cfg(feature = "alloc")]
120    pub fn compute_root_and_proof<'a, const MAX_N: u64, Item, Iter>(
121        leaves: Iter,
122        target_index: usize,
123    ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)>
124    where
125        [(); MAX_N.ilog2() as usize + 1]:,
126        Item: Into<[u8; OUT_LEN]>,
127        Iter: IntoIterator<Item = Item> + 'a,
128    {
129        // Stack of intermediate nodes per tree level
130        let mut stack = [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1];
131        // SAFETY: Inner value is `MaybeUninit`
132        let mut proof = unsafe {
133            Box::<[MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1]>::new_uninit()
134                .assume_init()
135        };
136
137        let (root, proof_length) =
138            Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, &mut proof)?;
139
140        let proof_capacity = proof.len();
141        let proof = Box::into_raw(proof);
142        // SAFETY: Points to correctly allocated memory where `proof_length` elements were
143        // initialized
144        let proof = unsafe {
145            Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
146        };
147
148        Some((root, proof))
149    }
150
151    /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
152    ///
153    /// Returns `Some(root, proof)` on success, `None` if index is outside of list of leaves.
154    ///
155    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
156    /// usage.
157    #[inline]
158    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
159    pub fn compute_root_and_proof_in<'a, 'proof, const MAX_N: u64, Item, Iter>(
160        leaves: Iter,
161        target_index: usize,
162        proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
163    ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])>
164    where
165        [(); MAX_N.ilog2() as usize + 1]:,
166        Item: Into<[u8; OUT_LEN]>,
167        Iter: IntoIterator<Item = Item> + 'a,
168    {
169        // Stack of intermediate nodes per tree level
170        let mut stack = [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1];
171
172        let (root, proof_length) =
173            Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, proof)?;
174        // SAFETY: Just correctly initialized `proof_length` elements
175        let proof = unsafe {
176            proof
177                .split_at_mut_unchecked(proof_length)
178                .0
179                .assume_init_mut()
180        };
181
182        Some((root, proof))
183    }
184
185    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
186    fn compute_root_and_proof_inner<'a, const MAX_N: u64, Item, Iter>(
187        leaves: Iter,
188        target_index: usize,
189        stack: &mut [[u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
190        proof: &mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
191    ) -> Option<([u8; OUT_LEN], usize)>
192    where
193        [(); MAX_N.ilog2() as usize + 1]:,
194        Item: Into<[u8; OUT_LEN]>,
195        Iter: IntoIterator<Item = Item> + 'a,
196    {
197        let mut proof_length = 0;
198        let mut num_leaves = 0_u64;
199
200        let mut current_target_level = None;
201        let mut position = target_index;
202
203        for (current_index, hash) in leaves.into_iter().enumerate() {
204            // How many leaves were processed so far. Should have been `num_leaves == MAX_N`, but
205            // `>=` helps compiler with panic safety checks.
206            if num_leaves >= MAX_N {
207                return None;
208            }
209
210            let mut current = hash.into();
211
212            // Every bit set to `1` corresponds to an active Merkle Tree level
213            let lowest_active_levels = num_leaves.trailing_ones() as usize;
214
215            if current_index == target_index {
216                for item in stack.iter().take(lowest_active_levels) {
217                    // If at the target leaf index, need to collect the proof
218                    // SAFETY: Method signature guarantees upper bound of the proof length
219                    unsafe { proof.get_unchecked_mut(proof_length) }.write(*item);
220                    proof_length += 1;
221
222                    current = hash_pair(item, &current);
223
224                    // Move up the tree
225                    position /= 2;
226                }
227
228                current_target_level = Some(lowest_active_levels);
229            } else {
230                for (level, item) in stack.iter().enumerate().take(lowest_active_levels) {
231                    if current_target_level == Some(level) {
232                        // SAFETY: Method signature guarantees upper bound of the proof length
233                        unsafe { proof.get_unchecked_mut(proof_length) }.write(
234                            if position.is_multiple_of(2) {
235                                current
236                            } else {
237                                *item
238                            },
239                        );
240                        proof_length += 1;
241
242                        current_target_level = Some(level + 1);
243
244                        // Move up the tree
245                        position /= 2;
246                    }
247
248                    current = hash_pair(item, &current);
249                }
250            }
251
252            // Place the current hash at the first inactive level
253            stack[lowest_active_levels] = current;
254            num_leaves += 1;
255        }
256
257        // `active_levels` here contains the number of leaves after above loop
258        if target_index >= num_leaves as usize {
259            // If no leaves were provided
260            return None;
261        }
262
263        let Some(current_target_level) = current_target_level else {
264            // Index not found
265            return None;
266        };
267
268        let mut stack_bits = num_leaves;
269
270        {
271            let lowest_active_level = stack_bits.trailing_zeros() as usize;
272            // Reuse `stack[0]` for resulting value
273            // SAFETY: Active level must have been set successfully before, hence it exists
274            stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
275            // Clear lowest active level
276            stack_bits &= !(1 << lowest_active_level);
277        }
278
279        // Hash remaining peaks (if any) of the potentially unbalanced tree together and collect
280        // proof hashes
281        let mut merged_peaks = false;
282        loop {
283            let lowest_active_level = stack_bits.trailing_zeros() as usize;
284
285            if lowest_active_level == u64::BITS as usize {
286                break;
287            }
288
289            // Clear lowest active level for next iteration
290            stack_bits &= !(1 << lowest_active_level);
291
292            // SAFETY: Active level must have been set successfully before, hence it exists
293            let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };
294
295            if lowest_active_level > current_target_level
296                || (lowest_active_level == current_target_level
297                    && !position.is_multiple_of(2)
298                    && !merged_peaks)
299            {
300                // SAFETY: Method signature guarantees upper bound of the proof length
301                unsafe { proof.get_unchecked_mut(proof_length) }.write(*lowest_active_level_item);
302                proof_length += 1;
303                merged_peaks = false;
304            } else if lowest_active_level == current_target_level {
305                // SAFETY: Method signature guarantees upper bound of the proof length
306                unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[0]);
307                proof_length += 1;
308                merged_peaks = false;
309            } else {
310                // Not collecting proof because of the need to merge peaks of an unbalanced tree
311                merged_peaks = true;
312            }
313
314            // Collect the lowest peak into the proof
315            stack[0] = hash_pair(lowest_active_level_item, &stack[0]);
316
317            position /= 2;
318        }
319
320        Some((stack[0], proof_length))
321    }
322
323    /// Verify a Merkle proof for a leaf at the given index
324    #[inline]
325    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
326    pub fn verify(
327        root: &[u8; OUT_LEN],
328        proof: &[[u8; OUT_LEN]],
329        leaf_index: u64,
330        leaf: [u8; OUT_LEN],
331        num_leaves: u64,
332    ) -> bool {
333        if leaf_index >= num_leaves {
334            return false;
335        }
336
337        let mut current = leaf;
338        let mut position = leaf_index;
339        let mut proof_pos = 0;
340        let mut level_size = num_leaves;
341
342        // Rebuild the path to the root
343        while level_size > 1 {
344            let is_left = position.is_multiple_of(2);
345            let is_last = position == level_size - 1;
346
347            if is_left && !is_last {
348                // Left node with a right sibling
349                if proof_pos >= proof.len() {
350                    // Missing sibling
351                    return false;
352                }
353                current = hash_pair(&current, &proof[proof_pos]);
354                proof_pos += 1;
355            } else if !is_left {
356                // Right node with a left sibling
357                if proof_pos >= proof.len() {
358                    // Missing sibling
359                    return false;
360                }
361                current = hash_pair(&proof[proof_pos], &current);
362                proof_pos += 1;
363            } else {
364                // Last node, no sibling, keep current
365            }
366
367            position /= 2;
368            // Size of next level
369            level_size = level_size.div_ceil(2);
370        }
371
372        // Check if proof is fully used and matches root
373        proof_pos == proof.len() && current == *root
374    }
375}