ab_merkle_tree/
unbalanced_hashed.rs

1use crate::hash_pair;
2#[cfg(feature = "alloc")]
3use alloc::boxed::Box;
4#[cfg(feature = "alloc")]
5use alloc::vec::Vec;
6use blake3::OUT_LEN;
7use core::mem::MaybeUninit;
8
/// Merkle Tree over already hashed leaves that supports an arbitrary number of elements.
///
/// It is the general case of [`BalancedHashedMerkleTree`]: whenever the number of leaves is a
/// power of two, both produce identical roots and proofs. [`BalancedHashedMerkleTree`] is more
/// efficient, so prefer it whenever it is applicable.
///
/// [`BalancedHashedMerkleTree`]: crate::balanced_hashed::BalancedHashedMerkleTree
///
/// No padding is added to the unbalanced tree, it is built the same way a Merkle Mountain Range
/// would be:
/// ```ignore
///               Root
///         /--------------\
///        H3              H4
///    /-------\         /----\
///   H0       H1       H2     \
///  /  \     /  \     /  \     \
/// L0  L1   L2  L3   L4  L5    L6
/// ```
#[derive(Debug)]
pub struct UnbalancedHashedMerkleTree;
29
30// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
31//  https://github.com/BLAKE3-team/BLAKE3/issues/478
32// TODO: Experiment with replacing a single pass with splitting the whole data set with a sequence
33//  of power-of-two elements that can be processed in parallel and do it recursively until a single
34//  element is left. This can be done for both root creation and proof generation.
35impl UnbalancedHashedMerkleTree {
36    /// Compute Merkle Tree Root.
37    ///
38    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
39    /// usage.
40    #[inline]
41    pub fn compute_root_only<'a, const N: usize, Iter>(leaves: Iter) -> Option<[u8; OUT_LEN]>
42    where
43        [(); N.ilog2() as usize + 1]:,
44        Iter: Iterator<Item = &'a [u8; OUT_LEN]> + 'a,
45    {
46        // Stack of intermediate nodes per tree level
47        let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
48        // Bitmask: bit `i = 1` if level `i` is active
49        let mut active_levels = 0_u64;
50
51        for &hash in leaves {
52            let mut current = hash;
53            let mut level = 0;
54
55            // Check if level is active by testing bit (active_levels & (1 << level))
56            while (active_levels & (1 << level)) != 0 {
57                current = hash_pair(&stack[level], &current);
58
59                // Clear the current level
60                active_levels &= !(1 << level);
61                level += 1;
62            }
63
64            // Place the current hash at the first inactive level
65            stack[level] = current;
66            // Set bit for level
67            active_levels |= 1 << level;
68        }
69
70        if active_levels == 0 {
71            // If no leaves were provided
72            return None;
73        }
74
75        {
76            let lowest_active_level = active_levels.trailing_zeros() as usize;
77            // Reuse `stack[0]` for resulting value
78            stack[0] = stack[lowest_active_level];
79            // Clear lowest active level
80            active_levels &= !(1 << lowest_active_level);
81        }
82
83        // Hash remaining peaks (if any) of the potentially unbalanced tree together
84        loop {
85            let lowest_active_level = active_levels.trailing_zeros() as usize;
86
87            if lowest_active_level == u64::BITS as usize {
88                break;
89            }
90
91            // Clear lowest active level
92            active_levels &= !(1 << lowest_active_level);
93
94            stack[0] = hash_pair(&stack[lowest_active_level], &stack[0]);
95        }
96
97        Some(stack[0])
98    }
99
100    /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
101    ///
102    /// Returns `Some(root, proof)` on success, `None` if index is outside of list of leaves.
103    ///
104    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
105    /// usage.
106    #[inline]
107    #[cfg(feature = "alloc")]
108    pub fn compute_root_and_proof<'a, const N: usize, Iter>(
109        leaves: Iter,
110        target_index: usize,
111    ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)>
112    where
113        [(); N.ilog2() as usize + 1]:,
114        Iter: Iterator<Item = &'a [u8; OUT_LEN]> + 'a,
115    {
116        // Stack of intermediate nodes per tree level
117        let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
118        // SAFETY: Inner value is `MaybeUninit`
119        let mut proof = unsafe {
120            Box::<[MaybeUninit<[u8; OUT_LEN]>; N.ilog2() as usize + 1]>::new_uninit().assume_init()
121        };
122
123        let (root, proof_length) =
124            Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, &mut proof)?;
125
126        let proof_capacity = proof.len();
127        let proof = Box::into_raw(proof);
128        // SAFETY: Points to correctly allocated memory where `proof_length` elements were
129        // initialized
130        let proof = unsafe {
131            Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
132        };
133
134        Some((root, proof))
135    }
136
137    /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
138    ///
139    /// Returns `Some(root, proof)` on success, `None` if index is outside of list of leaves.
140    ///
141    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
142    /// usage.
143    #[inline]
144    pub fn compute_root_and_proof_in<'a, 'proof, const N: usize, Iter>(
145        leaves: Iter,
146        target_index: usize,
147        proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; N.ilog2() as usize + 1],
148    ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])>
149    where
150        [(); N.ilog2() as usize + 1]:,
151        Iter: Iterator<Item = &'a [u8; OUT_LEN]> + 'a,
152    {
153        // Stack of intermediate nodes per tree level
154        let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
155
156        let (root, proof_length) =
157            Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, proof)?;
158        // SAFETY: Just correctly initialized `proof_length` elements
159        let proof = unsafe { proof[..proof_length].assume_init_mut() };
160
161        Some((root, proof))
162    }
163
164    fn compute_root_and_proof_inner<'a, const N: usize, Iter>(
165        leaves: Iter,
166        target_index: usize,
167        stack: &mut [[u8; OUT_LEN]; N.ilog2() as usize + 1],
168        proof: &mut [MaybeUninit<[u8; OUT_LEN]>; N.ilog2() as usize + 1],
169    ) -> Option<([u8; OUT_LEN], usize)>
170    where
171        [(); N.ilog2() as usize + 1]:,
172        Iter: Iterator<Item = &'a [u8; OUT_LEN]> + 'a,
173    {
174        let mut proof_length = 0;
175        let mut active_levels = 0_u64;
176
177        let mut current_target_level = None;
178        let mut position = target_index;
179
180        for (current_index, &hash) in leaves.enumerate() {
181            let mut current = hash;
182            let mut level = 0;
183
184            if current_index == target_index {
185                // Check if level is active by testing bit (active_levels & (1 << level))
186                while (active_levels & (1 << level)) != 0 {
187                    // If at the target leaf index, need to collect the proof
188                    // SAFETY: Method signature guarantees upper bound of the proof length
189                    unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[level]);
190                    proof_length += 1;
191
192                    current = hash_pair(&stack[level], &current);
193
194                    // Clear the current level
195                    active_levels &= !(1 << level);
196                    level += 1;
197
198                    // Move up the tree
199                    position /= 2;
200                }
201
202                current_target_level = Some(level);
203
204                // Place the current hash at the first inactive level
205                stack[level] = current;
206                // Set bit for level
207                active_levels |= 1 << level;
208            } else {
209                // If at the target leaf index, need to collect the proof
210                while (active_levels & (1 << level)) != 0 {
211                    if current_target_level == Some(level) {
212                        // SAFETY: Method signature guarantees upper bound of the proof length
213                        unsafe { proof.get_unchecked_mut(proof_length) }.write(
214                            if position % 2 == 0 {
215                                current
216                            } else {
217                                stack[level]
218                            },
219                        );
220                        proof_length += 1;
221
222                        current_target_level = Some(level + 1);
223
224                        // Move up the tree
225                        position /= 2;
226                    }
227
228                    current = hash_pair(&stack[level], &current);
229
230                    // Clear the current level
231                    active_levels &= !(1 << level);
232                    level += 1;
233                }
234
235                // Place the current hash at the first inactive level
236                stack[level] = current;
237                // Set bit for level
238                active_levels |= 1 << level;
239            }
240        }
241
242        // `active_levels` here contains the number of leaves after above loop
243        if target_index >= active_levels as usize {
244            // If no leaves were provided
245            return None;
246        }
247
248        let Some(current_target_level) = current_target_level else {
249            // Index not found
250            return None;
251        };
252
253        {
254            let lowest_active_level = active_levels.trailing_zeros() as usize;
255            // Reuse `stack[0]` for resulting value
256            stack[0] = stack[lowest_active_level];
257            // Clear lowest active level
258            active_levels &= !(1 << lowest_active_level);
259        }
260
261        // Hash remaining peaks (if any) of the potentially unbalanced tree together and collect
262        // proof hashes
263        let mut merged_peaks = false;
264        loop {
265            let lowest_active_level = active_levels.trailing_zeros() as usize;
266
267            if lowest_active_level == u64::BITS as usize {
268                break;
269            }
270
271            // Clear lowest active level
272            active_levels &= !(1 << lowest_active_level);
273
274            if lowest_active_level > current_target_level
275                || (lowest_active_level == current_target_level
276                    && (position % 2 != 0)
277                    && !merged_peaks)
278            {
279                // SAFETY: Method signature guarantees upper bound of the proof length
280                unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[lowest_active_level]);
281                proof_length += 1;
282                merged_peaks = false;
283            } else if lowest_active_level == current_target_level {
284                // SAFETY: Method signature guarantees upper bound of the proof length
285                unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[0]);
286                proof_length += 1;
287                merged_peaks = false;
288            } else {
289                // Not collecting proof because of the need to merge peaks of an unbalanced tree
290                merged_peaks = true;
291            }
292
293            // Collect the lowest peak into the proof
294            stack[0] = hash_pair(&stack[lowest_active_level], &stack[0]);
295
296            position /= 2;
297        }
298
299        Some((stack[0], proof_length))
300    }
301
302    /// Verify a Merkle proof for a leaf at the given index
303    #[inline]
304    // TODO: Make `num_leaves` optional in case the leaf is trusted (like just hashed from another
305    //  value and guaranteed not to use the same keyed hash as used here)
306    pub fn verify(
307        root: &[u8; OUT_LEN],
308        proof: &[[u8; OUT_LEN]],
309        leaf_index: usize,
310        leaf: [u8; OUT_LEN],
311        num_leaves: usize,
312    ) -> bool {
313        if leaf_index >= num_leaves {
314            return false;
315        }
316
317        let mut current = leaf;
318        let mut position = leaf_index;
319        let mut proof_pos = 0;
320        let mut level_size = num_leaves;
321
322        // Rebuild the path to the root
323        while level_size > 1 {
324            let is_left = position % 2 == 0;
325            let is_last = position == level_size - 1;
326
327            if is_left && !is_last {
328                // Left node with a right sibling
329                if proof_pos >= proof.len() {
330                    // Missing sibling
331                    return false;
332                }
333                current = hash_pair(&current, &proof[proof_pos]);
334                proof_pos += 1;
335            } else if !is_left {
336                // Right node with a left sibling
337                if proof_pos >= proof.len() {
338                    // Missing sibling
339                    return false;
340                }
341                current = hash_pair(&proof[proof_pos], &current);
342                proof_pos += 1;
343            } else {
344                // Last node, no sibling, keep current
345            }
346
347            position /= 2;
348            // Size of next level
349            level_size = level_size.div_ceil(2);
350        }
351
352        // Check if proof is fully used and matches root
353        proof_pos == proof.len() && current == *root
354    }
355}