ab_merkle_tree/
unbalanced_hashed.rs

1use crate::hash_pair;
2#[cfg(feature = "alloc")]
3use alloc::boxed::Box;
4#[cfg(feature = "alloc")]
5use alloc::vec::Vec;
6use blake3::OUT_LEN;
7use core::mem::MaybeUninit;
8
/// Merkle Tree variant that has pre-hashed leaves with an arbitrary number of elements.
///
/// This can be considered a general case of [`BalancedHashedMerkleTree`]. The root and proofs are
/// identical for both in case the number of leaves is a power of two. [`BalancedHashedMerkleTree`]
/// is more efficient and should be preferred when possible.
///
/// [`BalancedHashedMerkleTree`]: crate::balanced_hashed::BalancedHashedMerkleTree
///
/// The unbalanced tree is not padded; it is created the same way a Merkle Mountain Range would
/// be. For example, with seven leaves:
/// ```text
///               Root
///         /--------------\
///        H3              H4
///    /-------\         /----\
///   H0       H1       H2     \
///  /  \     /  \     /  \     \
/// L0  L1   L2  L3   L4  L5    L6
/// ```
///
/// The type itself carries no data: all functionality is exposed as associated functions.
#[derive(Debug)]
pub struct UnbalancedHashedMerkleTree;
29
30// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
31//  https://github.com/BLAKE3-team/BLAKE3/issues/478
32// TODO: Experiment with replacing a single pass with splitting the whole data set with a sequence
33//  of power-of-two elements that can be processed in parallel and do it recursively until a single
34//  element is left. This can be done for both root creation and proof generation.
35impl UnbalancedHashedMerkleTree {
36    /// Compute Merkle Tree Root.
37    ///
38    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
39    /// usage.
40    ///
41    /// Returns `None` for an empty list of leaves.
42    #[inline]
43    pub fn compute_root_only<'a, const N: usize, Item, Iter>(leaves: Iter) -> Option<[u8; OUT_LEN]>
44    where
45        [(); N.ilog2() as usize + 1]:,
46        Item: Into<[u8; OUT_LEN]>,
47        Iter: IntoIterator<Item = Item> + 'a,
48    {
49        // Stack of intermediate nodes per tree level
50        let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
51        // Bitmask: bit `i = 1` if level `i` is active
52        let mut active_levels = 0_u64;
53
54        for hash in leaves {
55            let mut current = hash.into();
56            let mut level = 0;
57
58            // Check if level is active by testing bit (active_levels & (1 << level))
59            while (active_levels & (1 << level)) != 0 {
60                current = hash_pair(&stack[level], &current);
61
62                // Clear the current level
63                active_levels &= !(1 << level);
64                level += 1;
65            }
66
67            // Place the current hash at the first inactive level
68            stack[level] = current;
69            // Set bit for level
70            active_levels |= 1 << level;
71        }
72
73        if active_levels == 0 {
74            // If no leaves were provided
75            return None;
76        }
77
78        {
79            let lowest_active_level = active_levels.trailing_zeros() as usize;
80            // Reuse `stack[0]` for resulting value
81            stack[0] = stack[lowest_active_level];
82            // Clear lowest active level
83            active_levels &= !(1 << lowest_active_level);
84        }
85
86        // Hash remaining peaks (if any) of the potentially unbalanced tree together
87        loop {
88            let lowest_active_level = active_levels.trailing_zeros() as usize;
89
90            if lowest_active_level == u64::BITS as usize {
91                break;
92            }
93
94            // Clear lowest active level
95            active_levels &= !(1 << lowest_active_level);
96
97            stack[0] = hash_pair(&stack[lowest_active_level], &stack[0]);
98        }
99
100        Some(stack[0])
101    }
102
103    /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
104    ///
105    /// Returns `Some(root, proof)` on success, `None` if index is outside of list of leaves.
106    ///
107    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
108    /// usage.
109    #[inline]
110    #[cfg(feature = "alloc")]
111    pub fn compute_root_and_proof<'a, const N: usize, Item, Iter>(
112        leaves: Iter,
113        target_index: usize,
114    ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)>
115    where
116        [(); N.ilog2() as usize + 1]:,
117        Item: Into<[u8; OUT_LEN]>,
118        Iter: IntoIterator<Item = Item> + 'a,
119    {
120        // Stack of intermediate nodes per tree level
121        let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
122        // SAFETY: Inner value is `MaybeUninit`
123        let mut proof = unsafe {
124            Box::<[MaybeUninit<[u8; OUT_LEN]>; N.ilog2() as usize + 1]>::new_uninit().assume_init()
125        };
126
127        let (root, proof_length) =
128            Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, &mut proof)?;
129
130        let proof_capacity = proof.len();
131        let proof = Box::into_raw(proof);
132        // SAFETY: Points to correctly allocated memory where `proof_length` elements were
133        // initialized
134        let proof = unsafe {
135            Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
136        };
137
138        Some((root, proof))
139    }
140
141    /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
142    ///
143    /// Returns `Some(root, proof)` on success, `None` if index is outside of list of leaves.
144    ///
145    /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
146    /// usage.
147    #[inline]
148    pub fn compute_root_and_proof_in<'a, 'proof, const N: usize, Item, Iter>(
149        leaves: Iter,
150        target_index: usize,
151        proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; N.ilog2() as usize + 1],
152    ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])>
153    where
154        [(); N.ilog2() as usize + 1]:,
155        Item: Into<[u8; OUT_LEN]>,
156        Iter: IntoIterator<Item = Item> + 'a,
157    {
158        // Stack of intermediate nodes per tree level
159        let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
160
161        let (root, proof_length) =
162            Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, proof)?;
163        // SAFETY: Just correctly initialized `proof_length` elements
164        let proof = unsafe { proof[..proof_length].assume_init_mut() };
165
166        Some((root, proof))
167    }
168
169    fn compute_root_and_proof_inner<'a, const N: usize, Item, Iter>(
170        leaves: Iter,
171        target_index: usize,
172        stack: &mut [[u8; OUT_LEN]; N.ilog2() as usize + 1],
173        proof: &mut [MaybeUninit<[u8; OUT_LEN]>; N.ilog2() as usize + 1],
174    ) -> Option<([u8; OUT_LEN], usize)>
175    where
176        [(); N.ilog2() as usize + 1]:,
177        Item: Into<[u8; OUT_LEN]>,
178        Iter: IntoIterator<Item = Item> + 'a,
179    {
180        let mut proof_length = 0;
181        let mut active_levels = 0_u64;
182
183        let mut current_target_level = None;
184        let mut position = target_index;
185
186        for (current_index, hash) in leaves.into_iter().enumerate() {
187            let mut current = hash.into();
188            let mut level = 0;
189
190            if current_index == target_index {
191                // Check if level is active by testing bit (active_levels & (1 << level))
192                while (active_levels & (1 << level)) != 0 {
193                    // If at the target leaf index, need to collect the proof
194                    // SAFETY: Method signature guarantees upper bound of the proof length
195                    unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[level]);
196                    proof_length += 1;
197
198                    current = hash_pair(&stack[level], &current);
199
200                    // Clear the current level
201                    active_levels &= !(1 << level);
202                    level += 1;
203
204                    // Move up the tree
205                    position /= 2;
206                }
207
208                current_target_level = Some(level);
209
210                // Place the current hash at the first inactive level
211                stack[level] = current;
212                // Set bit for level
213                active_levels |= 1 << level;
214            } else {
215                // If at the target leaf index, need to collect the proof
216                while (active_levels & (1 << level)) != 0 {
217                    if current_target_level == Some(level) {
218                        // SAFETY: Method signature guarantees upper bound of the proof length
219                        unsafe { proof.get_unchecked_mut(proof_length) }.write(
220                            if position % 2 == 0 {
221                                current
222                            } else {
223                                stack[level]
224                            },
225                        );
226                        proof_length += 1;
227
228                        current_target_level = Some(level + 1);
229
230                        // Move up the tree
231                        position /= 2;
232                    }
233
234                    current = hash_pair(&stack[level], &current);
235
236                    // Clear the current level
237                    active_levels &= !(1 << level);
238                    level += 1;
239                }
240
241                // Place the current hash at the first inactive level
242                stack[level] = current;
243                // Set bit for level
244                active_levels |= 1 << level;
245            }
246        }
247
248        // `active_levels` here contains the number of leaves after above loop
249        if target_index >= active_levels as usize {
250            // If no leaves were provided
251            return None;
252        }
253
254        let Some(current_target_level) = current_target_level else {
255            // Index not found
256            return None;
257        };
258
259        {
260            let lowest_active_level = active_levels.trailing_zeros() as usize;
261            // Reuse `stack[0]` for resulting value
262            stack[0] = stack[lowest_active_level];
263            // Clear lowest active level
264            active_levels &= !(1 << lowest_active_level);
265        }
266
267        // Hash remaining peaks (if any) of the potentially unbalanced tree together and collect
268        // proof hashes
269        let mut merged_peaks = false;
270        loop {
271            let lowest_active_level = active_levels.trailing_zeros() as usize;
272
273            if lowest_active_level == u64::BITS as usize {
274                break;
275            }
276
277            // Clear lowest active level
278            active_levels &= !(1 << lowest_active_level);
279
280            if lowest_active_level > current_target_level
281                || (lowest_active_level == current_target_level
282                    && (position % 2 != 0)
283                    && !merged_peaks)
284            {
285                // SAFETY: Method signature guarantees upper bound of the proof length
286                unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[lowest_active_level]);
287                proof_length += 1;
288                merged_peaks = false;
289            } else if lowest_active_level == current_target_level {
290                // SAFETY: Method signature guarantees upper bound of the proof length
291                unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[0]);
292                proof_length += 1;
293                merged_peaks = false;
294            } else {
295                // Not collecting proof because of the need to merge peaks of an unbalanced tree
296                merged_peaks = true;
297            }
298
299            // Collect the lowest peak into the proof
300            stack[0] = hash_pair(&stack[lowest_active_level], &stack[0]);
301
302            position /= 2;
303        }
304
305        Some((stack[0], proof_length))
306    }
307
308    /// Verify a Merkle proof for a leaf at the given index
309    #[inline]
310    pub fn verify(
311        root: &[u8; OUT_LEN],
312        proof: &[[u8; OUT_LEN]],
313        leaf_index: usize,
314        leaf: [u8; OUT_LEN],
315        num_leaves: usize,
316    ) -> bool {
317        if leaf_index >= num_leaves {
318            return false;
319        }
320
321        let mut current = leaf;
322        let mut position = leaf_index;
323        let mut proof_pos = 0;
324        let mut level_size = num_leaves;
325
326        // Rebuild the path to the root
327        while level_size > 1 {
328            let is_left = position % 2 == 0;
329            let is_last = position == level_size - 1;
330
331            if is_left && !is_last {
332                // Left node with a right sibling
333                if proof_pos >= proof.len() {
334                    // Missing sibling
335                    return false;
336                }
337                current = hash_pair(&current, &proof[proof_pos]);
338                proof_pos += 1;
339            } else if !is_left {
340                // Right node with a left sibling
341                if proof_pos >= proof.len() {
342                    // Missing sibling
343                    return false;
344                }
345                current = hash_pair(&proof[proof_pos], &current);
346                proof_pos += 1;
347            } else {
348                // Last node, no sibling, keep current
349            }
350
351            position /= 2;
352            // Size of next level
353            level_size = level_size.div_ceil(2);
354        }
355
356        // Check if proof is fully used and matches root
357        proof_pos == proof.len() && current == *root
358    }
359}