ab_merkle_tree/
unbalanced.rs

1use crate::hash_pair;
2use ab_blake3::OUT_LEN;
3#[cfg(feature = "alloc")]
4use alloc::boxed::Box;
5#[cfg(feature = "alloc")]
6use alloc::vec::Vec;
7use core::mem::MaybeUninit;
8
/// Merkle Tree over pre-hashed leaves that accepts an arbitrary number of leaves.
///
/// It is a generalization of [`BalancedMerkleTree`]: whenever the number of leaves is a power of
/// two, both produce identical roots and proofs. [`BalancedMerkleTree`] is more efficient, so
/// prefer it whenever it is applicable.
///
/// [`BalancedMerkleTree`]: crate::balanced::BalancedMerkleTree
///
/// No padding is applied to the leaves. Instead, the tree is assembled the same way a Merkle
/// Mountain Range would be, which for seven leaves looks like this:
/// ```text
///               Root
///         /--------------\
///        H3              H4
///    /-------\         /----\
///   H0       H1       H2     \
///  /  \     /  \     /  \     \
/// L0  L1   L2  L3   L4  L5    L6
/// ```
#[derive(Debug)]
pub struct UnbalancedMerkleTree;
29
// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
//  https://github.com/BLAKE3-team/BLAKE3/issues/478
// TODO: Experiment with replacing the single pass by splitting the whole data set into a sequence
//  of power-of-two-sized runs of elements that can be processed in parallel, repeating this
//  recursively until a single element is left. This can be done for both root creation and proof
//  generation.
35impl UnbalancedMerkleTree {
36    /// Compute Merkle Tree Root.
37    ///
38    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
39    /// usage.
40    ///
41    /// Returns `None` for an empty list of leaves.
42    #[inline]
43    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
44    pub fn compute_root_only<'a, const MAX_N: u64, Item, Iter>(
45        leaves: Iter,
46    ) -> Option<[u8; OUT_LEN]>
47    where
48        [(); MAX_N.next_power_of_two().ilog2() as usize + 1]:,
49        Item: Into<[u8; OUT_LEN]>,
50        Iter: IntoIterator<Item = Item> + 'a,
51    {
52        // Stack of intermediate nodes per tree level
53        let mut stack = [[0u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1];
54        let mut num_leaves = 0_u64;
55
56        for hash in leaves {
57            // How many leaves were processed so far. Should have been `num_leaves == MAX_N`, but
58            // `>=` helps compiler with panic safety checks.
59            if num_leaves >= MAX_N {
60                return None;
61            }
62
63            let mut current = hash.into();
64
65            // Every bit set to `1` corresponds to an active Merkle Tree level
66            let lowest_active_levels = num_leaves.trailing_ones() as usize;
67            for item in stack.iter().take(lowest_active_levels) {
68                current = hash_pair(item, &current);
69            }
70
71            // Place the current hash at the first inactive level
72            stack[lowest_active_levels] = current;
73            num_leaves += 1;
74        }
75
76        if num_leaves == 0 {
77            // If no leaves were provided
78            return None;
79        }
80
81        let mut stack_bits = num_leaves;
82
83        {
84            let lowest_active_level = stack_bits.trailing_zeros() as usize;
85            // Reuse `stack[0]` for resulting value
86            // SAFETY: Active level must have been set successfully before, hence it exists
87            stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
88            // Clear lowest active level
89            stack_bits &= !(1 << lowest_active_level);
90        }
91
92        // Hash remaining peaks (if any) of the potentially unbalanced tree together
93        loop {
94            let lowest_active_level = stack_bits.trailing_zeros() as usize;
95
96            if lowest_active_level == u64::BITS as usize {
97                break;
98            }
99
100            // Clear lowest active level for next iteration
101            stack_bits &= !(1 << lowest_active_level);
102
103            // SAFETY: Active level must have been set successfully before, hence it exists
104            let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };
105
106            stack[0] = hash_pair(lowest_active_level_item, &stack[0]);
107        }
108
109        Some(stack[0])
110    }
111
112    /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
113    ///
114    /// Returns `Some((root, proof))` on success, `None` if index is outside of list of leaves.
115    ///
116    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
117    /// usage.
118    #[inline]
119    #[cfg(feature = "alloc")]
120    pub fn compute_root_and_proof<'a, const MAX_N: u64, Item, Iter>(
121        leaves: Iter,
122        target_index: usize,
123    ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)>
124    where
125        [(); MAX_N.next_power_of_two().ilog2() as usize + 1]:,
126        Item: Into<[u8; OUT_LEN]>,
127        Iter: IntoIterator<Item = Item> + 'a,
128    {
129        // Stack of intermediate nodes per tree level
130        let mut stack = [[0u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1];
131        // SAFETY: Inner value is `MaybeUninit`
132        let mut proof = unsafe {
133            Box::<[MaybeUninit<[u8; OUT_LEN]>; MAX_N.next_power_of_two().ilog2() as usize]>::new_uninit().assume_init()
134        };
135
136        let (root, proof_length) =
137            Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, &mut proof)?;
138
139        let proof_capacity = proof.len();
140        let proof = Box::into_raw(proof);
141        // SAFETY: Points to correctly allocated memory where `proof_length` elements were
142        // initialized
143        let proof = unsafe {
144            Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
145        };
146
147        Some((root, proof))
148    }
149
150    /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
151    ///
152    /// Returns `Some((root, proof))` on success, `None` if index is outside of list of leaves.
153    ///
154    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
155    /// usage.
156    #[inline]
157    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
158    pub fn compute_root_and_proof_in<'a, 'proof, const MAX_N: u64, Item, Iter>(
159        leaves: Iter,
160        target_index: usize,
161        proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.next_power_of_two().ilog2() as usize],
162    ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])>
163    where
164        [(); MAX_N.next_power_of_two().ilog2() as usize + 1]:,
165        Item: Into<[u8; OUT_LEN]>,
166        Iter: IntoIterator<Item = Item> + 'a,
167    {
168        // Stack of intermediate nodes per tree level
169        let mut stack = [[0u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1];
170
171        let (root, proof_length) =
172            Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, proof)?;
173        // SAFETY: Just correctly initialized `proof_length` elements
174        let proof = unsafe { proof.get_unchecked_mut(..proof_length).assume_init_mut() };
175
176        Some((root, proof))
177    }
178
    /// Shared implementation behind `compute_root_and_proof` and `compute_root_and_proof_in`.
    ///
    /// Consumes `leaves` in a single pass, using `stack` as scratch space for one intermediate
    /// node per tree level and writing the proof hashes for the leaf at `target_index` into the
    /// beginning of `proof`.
    ///
    /// Returns `Some((root, proof_length))`, where `proof_length` is the number of leading entries
    /// of `proof` that were written. Returns `None` if `leaves` yields more than `MAX_N` items or
    /// if `target_index` does not refer to one of the leaves (which includes an empty list of
    /// leaves).
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    fn compute_root_and_proof_inner<'a, const MAX_N: u64, Item, Iter>(
        leaves: Iter,
        target_index: usize,
        stack: &mut [[u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1],
        proof: &mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.next_power_of_two().ilog2() as usize],
    ) -> Option<([u8; OUT_LEN], usize)>
    where
        Item: Into<[u8; OUT_LEN]>,
        Iter: IntoIterator<Item = Item> + 'a,
    {
        // Number of proof entries written so far
        let mut proof_length = 0;
        let mut num_leaves = 0_u64;

        // Level of `stack` that holds the subtree containing the target leaf, `None` until the
        // target leaf was processed
        let mut current_target_level = None;
        // Starts as the index of the target leaf and is halved whenever moving up the tree
        let mut position = target_index;

        for (current_index, hash) in leaves.into_iter().enumerate() {
            // How many leaves were processed so far. Should have been `num_leaves == MAX_N`, but
            // `>=` helps compiler with panic safety checks.
            if num_leaves >= MAX_N {
                return None;
            }

            let mut current = hash.into();

            // Every bit set to `1` corresponds to an active Merkle Tree level
            let lowest_active_levels = num_leaves.trailing_ones() as usize;

            if current_index == target_index {
                for item in stack.iter().take(lowest_active_levels) {
                    // If at the target leaf index, need to collect the proof
                    // SAFETY: Method signature guarantees upper bound of the proof length
                    unsafe { proof.get_unchecked_mut(proof_length) }.write(*item);
                    proof_length += 1;

                    current = hash_pair(item, &current);

                    // Move up the tree
                    position /= 2;
                }

                // The subtree containing the target leaf is stored at this level below
                current_target_level = Some(lowest_active_levels);
            } else {
                for (level, item) in stack.iter().enumerate().take(lowest_active_levels) {
                    if current_target_level == Some(level) {
                        // The subtree with the target leaf is one of the two nodes being merged,
                        // the other one is recorded in the proof: `current` (the right node) for
                        // an even position, `item` (the left node) for an odd one
                        // SAFETY: Method signature guarantees upper bound of the proof length
                        unsafe { proof.get_unchecked_mut(proof_length) }.write(
                            if position.is_multiple_of(2) {
                                current
                            } else {
                                *item
                            },
                        );
                        proof_length += 1;

                        // The merged node, which now contains the target leaf, is one level up
                        current_target_level = Some(level + 1);

                        // Move up the tree
                        position /= 2;
                    }

                    current = hash_pair(item, &current);
                }
            }

            // Place the current hash at the first inactive level
            stack[lowest_active_levels] = current;
            num_leaves += 1;
        }

        // `num_leaves` here contains the total number of leaves consumed by the loop above
        if target_index >= num_leaves as usize {
            // Target index is beyond the last leaf, this also covers an empty list of leaves
            return None;
        }

        let Some(current_target_level) = current_target_level else {
            // Index not found
            return None;
        };

        // Every bit set to `1` corresponds to a level of `stack` holding a peak that was not
        // merged into the root yet
        let mut stack_bits = num_leaves;

        {
            let lowest_active_level = stack_bits.trailing_zeros() as usize;
            // Reuse `stack[0]` for resulting value
            // SAFETY: Active level must have been set successfully before, hence it exists
            stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
            // Clear lowest active level
            stack_bits &= !(1 << lowest_active_level);
        }

        // Hash remaining peaks (if any) of the potentially unbalanced tree together and collect
        // proof hashes
        let mut merged_peaks = false;
        loop {
            let lowest_active_level = stack_bits.trailing_zeros() as usize;

            // `trailing_zeros()` equals the bit width only when no bits are set, meaning all
            // peaks were processed
            if lowest_active_level == u64::BITS as usize {
                break;
            }

            // Clear lowest active level for next iteration
            stack_bits &= !(1 << lowest_active_level);

            // SAFETY: Active level must have been set successfully before, hence it exists
            let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };

            if lowest_active_level > current_target_level
                || (lowest_active_level == current_target_level
                    && !position.is_multiple_of(2)
                    && !merged_peaks)
            {
                // The peak itself is recorded in the proof
                // SAFETY: Method signature guarantees upper bound of the proof length
                unsafe { proof.get_unchecked_mut(proof_length) }.write(*lowest_active_level_item);
                proof_length += 1;
                merged_peaks = false;
            } else if lowest_active_level == current_target_level {
                // The value accumulated in `stack[0]` from the lower peaks is recorded in the
                // proof
                // SAFETY: Method signature guarantees upper bound of the proof length
                unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[0]);
                proof_length += 1;
                merged_peaks = false;
            } else {
                // Not collecting proof because of the need to merge peaks of an unbalanced tree
                merged_peaks = true;
            }

            // Merge the peak (as the left node) into the value accumulated in `stack[0]`
            stack[0] = hash_pair(lowest_active_level_item, &stack[0]);

            position /= 2;
        }

        Some((stack[0], proof_length))
    }
315
316    /// Verify a Merkle proof for a leaf at the given index
317    #[inline]
318    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
319    pub fn verify(
320        root: &[u8; OUT_LEN],
321        proof: &[[u8; OUT_LEN]],
322        leaf_index: u64,
323        leaf: [u8; OUT_LEN],
324        num_leaves: u64,
325    ) -> bool {
326        if leaf_index >= num_leaves {
327            return false;
328        }
329
330        let mut current = leaf;
331        let mut position = leaf_index;
332        let mut proof_pos = 0;
333        let mut level_size = num_leaves;
334
335        // Rebuild the path to the root
336        while level_size > 1 {
337            let is_left = position.is_multiple_of(2);
338            let is_last = position == level_size - 1;
339
340            if is_left && !is_last {
341                // Left node with a right sibling
342                if proof_pos >= proof.len() {
343                    // Missing sibling
344                    return false;
345                }
346                current = hash_pair(&current, &proof[proof_pos]);
347                proof_pos += 1;
348            } else if !is_left {
349                // Right node with a left sibling
350                if proof_pos >= proof.len() {
351                    // Missing sibling
352                    return false;
353                }
354                current = hash_pair(&proof[proof_pos], &current);
355                proof_pos += 1;
356            } else {
357                // Last node, no sibling, keep current
358            }
359
360            position /= 2;
361            // Size of next level
362            level_size = level_size.div_ceil(2);
363        }
364
365        // Check if proof is fully used and matches root
366        proof_pos == proof.len() && current == *root
367    }
368}