// ab_merkle_tree/unbalanced.rs

use crate::hash_pair;
use ab_blake3::OUT_LEN;
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use core::mem::MaybeUninit;
8
/// Merkle Tree variant that has pre-hashed leaves with arbitrary number of elements.
///
/// This can be considered a general case of [`BalancedMerkleTree`]. The root and proofs are
/// identical for both in case the number of leaves is a power of two. [`BalancedMerkleTree`] is
/// more efficient and should be preferred when possible.
///
/// [`BalancedMerkleTree`]: crate::balanced::BalancedMerkleTree
///
/// The unbalanced tree is not padded, it is created the same way Merkle Mountain Range would be:
/// ```text
///               Root
///         /--------------\
///        H3              H4
///    /-------\         /----\
///   H0       H1       H2     \
///  /  \     /  \     /  \     \
/// L0  L1   L2  L3   L4  L5    L6
/// ```
///
/// The type carries no state; all functionality is exposed through associated functions.
#[derive(Debug)]
pub struct UnbalancedMerkleTree;
29
// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
//  https://github.com/BLAKE3-team/BLAKE3/issues/478
// TODO: Experiment with replacing the single pass by splitting the whole data set into a sequence
//  of power-of-two sized runs that can be processed in parallel, recursively, until a single
//  element is left. This can be done for both root creation and proof generation.
35impl UnbalancedMerkleTree {
36    /// Compute Merkle Tree Root.
37    ///
38    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
39    /// usage.
40    ///
41    /// Returns `None` for an empty list of leaves, or if the number of leaves is larger than
42    /// `MAX_N`.
43    #[inline]
44    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
45    pub fn compute_root_only<'a, const MAX_N: u64, Item, Iter>(
46        leaves: Iter,
47    ) -> Option<[u8; OUT_LEN]>
48    where
49        [(); MAX_N.next_power_of_two().ilog2() as usize + 1]:,
50        Item: Into<[u8; OUT_LEN]>,
51        Iter: IntoIterator<Item = Item> + 'a,
52    {
53        // Stack of intermediate nodes per tree level
54        let mut stack = [[0u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1];
55        let mut num_leaves = 0_u64;
56
57        for hash in leaves {
58            // How many leaves were processed so far. Should have been `num_leaves == MAX_N`, but
59            // `>=` helps compiler with panic safety checks.
60            if num_leaves >= MAX_N {
61                return None;
62            }
63
64            let mut current = hash.into();
65
66            // Every bit set to `1` corresponds to an active Merkle Tree level
67            let lowest_active_levels = num_leaves.trailing_ones() as usize;
68            for item in stack.iter().take(lowest_active_levels) {
69                current = hash_pair(item, &current);
70            }
71
72            // Place the current hash at the first inactive level
73            stack[lowest_active_levels] = current;
74            num_leaves += 1;
75        }
76
77        if num_leaves == 0 {
78            // If no leaves were provided
79            return None;
80        }
81
82        let mut stack_bits = num_leaves;
83
84        {
85            let lowest_active_level = stack_bits.trailing_zeros() as usize;
86            // Reuse `stack[0]` for resulting value
87            // SAFETY: Active level must have been set successfully before, hence it exists
88            stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
89            // Clear lowest active level
90            stack_bits &= !(1 << lowest_active_level);
91        }
92
93        // Hash remaining peaks (if any) of the potentially unbalanced tree together
94        loop {
95            let lowest_active_level = stack_bits.trailing_zeros() as usize;
96
97            if lowest_active_level == u64::BITS as usize {
98                break;
99            }
100
101            // Clear lowest active level for next iteration
102            stack_bits &= !(1 << lowest_active_level);
103
104            // SAFETY: Active level must have been set successfully before, hence it exists
105            let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };
106
107            stack[0] = hash_pair(lowest_active_level_item, &stack[0]);
108        }
109
110        Some(stack[0])
111    }
112
113    /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
114    ///
115    /// Returns `Some((root, proof))` on success, `None` if index is outside of list of leaves.
116    ///
117    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
118    /// usage.
119    #[inline]
120    #[cfg(feature = "alloc")]
121    pub fn compute_root_and_proof<'a, const MAX_N: u64, Item, Iter>(
122        leaves: Iter,
123        target_index: usize,
124    ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)>
125    where
126        [(); MAX_N.next_power_of_two().ilog2() as usize + 1]:,
127        Item: Into<[u8; OUT_LEN]>,
128        Iter: IntoIterator<Item = Item> + 'a,
129    {
130        // Stack of intermediate nodes per tree level
131        let mut stack = [[0u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1];
132        // SAFETY: Inner value is `MaybeUninit`
133        let mut proof = unsafe {
134            Box::<[MaybeUninit<[u8; OUT_LEN]>; MAX_N.next_power_of_two().ilog2() as usize]>::new_uninit().assume_init()
135        };
136
137        let (root, proof_length) =
138            Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, &mut proof)?;
139
140        let proof_capacity = proof.len();
141        let proof = Box::into_raw(proof);
142        // SAFETY: Points to correctly allocated memory where `proof_length` elements were
143        // initialized
144        let proof = unsafe {
145            Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
146        };
147
148        Some((root, proof))
149    }
150
151    /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
152    ///
153    /// Returns `Some((root, proof))` on success, `None` if index is outside of list of leaves.
154    ///
155    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
156    /// usage.
157    #[inline]
158    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
159    pub fn compute_root_and_proof_in<'a, 'proof, const MAX_N: u64, Item, Iter>(
160        leaves: Iter,
161        target_index: usize,
162        proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.next_power_of_two().ilog2() as usize],
163    ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])>
164    where
165        [(); MAX_N.next_power_of_two().ilog2() as usize + 1]:,
166        Item: Into<[u8; OUT_LEN]>,
167        Iter: IntoIterator<Item = Item> + 'a,
168    {
169        // Stack of intermediate nodes per tree level
170        let mut stack = [[0u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1];
171
172        let (root, proof_length) =
173            Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, proof)?;
174        // SAFETY: Just correctly initialized `proof_length` elements
175        let proof = unsafe { proof.get_unchecked_mut(..proof_length).assume_init_mut() };
176
177        Some((root, proof))
178    }
179
180    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
181    fn compute_root_and_proof_inner<'a, const MAX_N: u64, Item, Iter>(
182        leaves: Iter,
183        target_index: usize,
184        stack: &mut [[u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1],
185        proof: &mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.next_power_of_two().ilog2() as usize],
186    ) -> Option<([u8; OUT_LEN], usize)>
187    where
188        Item: Into<[u8; OUT_LEN]>,
189        Iter: IntoIterator<Item = Item> + 'a,
190    {
191        let mut proof_length = 0;
192        let mut num_leaves = 0_u64;
193
194        let mut current_target_level = None;
195        let mut position = target_index;
196
197        for (current_index, hash) in leaves.into_iter().enumerate() {
198            // How many leaves were processed so far. Should have been `num_leaves == MAX_N`, but
199            // `>=` helps compiler with panic safety checks.
200            if num_leaves >= MAX_N {
201                return None;
202            }
203
204            let mut current = hash.into();
205
206            // Every bit set to `1` corresponds to an active Merkle Tree level
207            let lowest_active_levels = num_leaves.trailing_ones() as usize;
208
209            if current_index == target_index {
210                for item in stack.iter().take(lowest_active_levels) {
211                    // If at the target leaf index, need to collect the proof
212                    // SAFETY: Method signature guarantees upper bound of the proof length
213                    unsafe { proof.get_unchecked_mut(proof_length) }.write(*item);
214                    proof_length += 1;
215
216                    current = hash_pair(item, &current);
217
218                    // Move up the tree
219                    position /= 2;
220                }
221
222                current_target_level = Some(lowest_active_levels);
223            } else {
224                for (level, item) in stack.iter().enumerate().take(lowest_active_levels) {
225                    if current_target_level == Some(level) {
226                        // SAFETY: Method signature guarantees upper bound of the proof length
227                        unsafe { proof.get_unchecked_mut(proof_length) }.write(
228                            if position.is_multiple_of(2) {
229                                current
230                            } else {
231                                *item
232                            },
233                        );
234                        proof_length += 1;
235
236                        current_target_level = Some(level + 1);
237
238                        // Move up the tree
239                        position /= 2;
240                    }
241
242                    current = hash_pair(item, &current);
243                }
244            }
245
246            // Place the current hash at the first inactive level
247            stack[lowest_active_levels] = current;
248            num_leaves += 1;
249        }
250
251        // `active_levels` here contains the number of leaves after above loop
252        if target_index >= num_leaves as usize {
253            // If no leaves were provided
254            return None;
255        }
256
257        let Some(current_target_level) = current_target_level else {
258            // Index not found
259            return None;
260        };
261
262        let mut stack_bits = num_leaves;
263
264        {
265            let lowest_active_level = stack_bits.trailing_zeros() as usize;
266            // Reuse `stack[0]` for resulting value
267            // SAFETY: Active level must have been set successfully before, hence it exists
268            stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
269            // Clear lowest active level
270            stack_bits &= !(1 << lowest_active_level);
271        }
272
273        // Hash remaining peaks (if any) of the potentially unbalanced tree together and collect
274        // proof hashes
275        let mut merged_peaks = false;
276        loop {
277            let lowest_active_level = stack_bits.trailing_zeros() as usize;
278
279            if lowest_active_level == u64::BITS as usize {
280                break;
281            }
282
283            // Clear lowest active level for next iteration
284            stack_bits &= !(1 << lowest_active_level);
285
286            // SAFETY: Active level must have been set successfully before, hence it exists
287            let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };
288
289            if lowest_active_level > current_target_level
290                || (lowest_active_level == current_target_level
291                    && !position.is_multiple_of(2)
292                    && !merged_peaks)
293            {
294                // SAFETY: Method signature guarantees upper bound of the proof length
295                unsafe { proof.get_unchecked_mut(proof_length) }.write(*lowest_active_level_item);
296                proof_length += 1;
297                merged_peaks = false;
298            } else if lowest_active_level == current_target_level {
299                // SAFETY: Method signature guarantees upper bound of the proof length
300                unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[0]);
301                proof_length += 1;
302                merged_peaks = false;
303            } else {
304                // Not collecting proof because of the need to merge peaks of an unbalanced tree
305                merged_peaks = true;
306            }
307
308            // Collect the lowest peak into the proof
309            stack[0] = hash_pair(lowest_active_level_item, &stack[0]);
310
311            position /= 2;
312        }
313
314        Some((stack[0], proof_length))
315    }
316
317    /// Verify a Merkle proof for a leaf at the given index
318    #[inline]
319    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
320    pub fn verify(
321        root: &[u8; OUT_LEN],
322        proof: &[[u8; OUT_LEN]],
323        leaf_index: u64,
324        leaf: [u8; OUT_LEN],
325        num_leaves: u64,
326    ) -> bool {
327        if leaf_index >= num_leaves {
328            return false;
329        }
330
331        let mut current = leaf;
332        let mut position = leaf_index;
333        let mut proof_pos = 0;
334        let mut level_size = num_leaves;
335
336        // Rebuild the path to the root
337        while level_size > 1 {
338            let is_left = position.is_multiple_of(2);
339            let is_last = position == level_size - 1;
340
341            if is_left && !is_last {
342                // Left node with a right sibling
343                if proof_pos >= proof.len() {
344                    // Missing sibling
345                    return false;
346                }
347                current = hash_pair(&current, &proof[proof_pos]);
348                proof_pos += 1;
349            } else if !is_left {
350                // Right node with a left sibling
351                if proof_pos >= proof.len() {
352                    // Missing sibling
353                    return false;
354                }
355                current = hash_pair(&proof[proof_pos], &current);
356                proof_pos += 1;
357            } else {
358                // Last node, no sibling, keep current
359            }
360
361            position /= 2;
362            // Size of next level
363            level_size = level_size.div_ceil(2);
364        }
365
366        // Check if proof is fully used and matches root
367        proof_pos == proof.len() && current == *root
368    }
369}