ab_merkle_tree/balanced.rs

use crate::{hash_pair, hash_pair_block, hash_pairs};
use ab_blake3::{BLOCK_LEN, OUT_LEN};
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
use core::iter::TrustedLen;
use core::mem;
use core::mem::MaybeUninit;
use core::num::NonZero;

/// Optimal number of blocks for hashing at once to saturate BLAKE3 SIMD on any hardware
const BATCH_HASH_NUM_BLOCKS: usize = 16;
/// Number of leaves that corresponds to [`BATCH_HASH_NUM_BLOCKS`]
const BATCH_HASH_NUM_LEAVES: usize = BATCH_HASH_NUM_BLOCKS * BLOCK_LEN / OUT_LEN;
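// NOTE: With BLAKE3's standard sizes (`BLOCK_LEN == 64`, `OUT_LEN == 32`) one block holds exactly
// one pair of leaf hashes, so `BATCH_HASH_NUM_LEAVES` works out to 32 leaves per batch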

/// Inner function used in [`BalancedMerkleTree::compute_root_only()`] for stack allocation; it is
/// only public because it is used in generic bounds.
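///
/// For example, with BLAKE3's standard sizes (`BATCH_HASH_NUM_LEAVES == 32`; illustrative values
/// only, not a doc-test since the batching constants are private):
///
/// ```ignore
/// assert_eq!(compute_root_only_large_stack_size(32), 1);
/// assert_eq!(compute_root_only_large_stack_size(64), 2);
/// assert_eq!(compute_root_only_large_stack_size(1024), 6);
/// ```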
pub const fn compute_root_only_large_stack_size(n: usize) -> usize {
    // For small trees the large stack is not used, so the returned value does not matter as long as
    // it compiles
    if n < BATCH_HASH_NUM_LEAVES {
        return 1;
    }

    (n / BATCH_HASH_NUM_LEAVES).ilog2() as usize + 1
}

/// Ensures that only a supported `N` can be specified for [`BalancedMerkleTree`].
///
/// This is essentially a workaround for current Rust type system constraints that do not allow a
/// nicer way to do the same thing at compile time.
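///
/// A sketch of how this is enforced through a generic bound (illustrative, not a doc-test; any
/// non-power-of-two or single-leaf `N` fails to compile):
///
/// ```ignore
/// fn requires_supported_n<const N: usize>()
/// where
///     [(); ensure_supported_n(N)]:,
/// {
/// }
///
/// requires_supported_n::<8>(); // Compiles
/// // requires_supported_n::<6>(); // Would fail: 6 is not a power of two
/// ```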
pub const fn ensure_supported_n(n: usize) -> usize {
    assert!(
        n.is_power_of_two(),
        "Balanced Merkle Tree must have a number of leaves that is a power of 2"
    );

    assert!(
        n > 1,
        "Balanced Merkle Tree must have more than one leaf"
    );

    0
}

/// Merkle Tree variant that has hash-sized leaves and is fully balanced according to the
/// configured generic parameter.
///
/// This can be considered a special case of [`UnbalancedMerkleTree`]: the root and proofs are
/// identical for both when the number of leaves is a power of two. Even then,
/// [`UnbalancedMerkleTree`] is useful when a single proof needs to be generated and the number of
/// leaves is very large (it can generate proofs with very little RAM usage compared to this
/// version).
///
/// [`UnbalancedMerkleTree`]: crate::unbalanced::UnbalancedMerkleTree
///
/// This Merkle Tree implementation is best suited for use cases where proofs for all (or most) of
/// the elements need to be generated and the whole tree easily fits into memory. It can also be
/// constructed, and proofs can be generated, efficiently without heap allocations.
///
/// With all parameters of the tree known statically, the most efficient version of the code is
/// generated for a given set of parameters.
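///
/// # Example
///
/// A minimal sketch (the crate path and `OUT_LEN == 32` are assumed; not a doc-test):
///
/// ```ignore
/// use ab_merkle_tree::balanced::BalancedMerkleTree;
///
/// let leaves = [[0u8; 32]; 8];
/// let tree = BalancedMerkleTree::new(&leaves);
/// let root = tree.root();
///
/// for (leaf_index, proof) in tree.all_proofs().enumerate() {
///     assert!(BalancedMerkleTree::<8>::verify(
///         &root,
///         &proof,
///         leaf_index,
///         leaves[leaf_index],
///     ));
/// }
/// ```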
#[derive(Debug)]
pub struct BalancedMerkleTree<'a, const N: usize>
where
    [(); N - 1]:,
{
    leaves: &'a [[u8; OUT_LEN]],
    // This tree doesn't include leaves because we have them in the `leaves` field
    tree: [[u8; OUT_LEN]; N - 1],
}

// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
//  https://github.com/BLAKE3-team/BLAKE3/issues/478
impl<'a, const N: usize> BalancedMerkleTree<'a, N>
where
    [(); N - 1]:,
    [(); ensure_supported_n(N)]:,
{
    /// Create a new tree from a fixed set of elements.
    ///
    /// The data structure is statically allocated and might be too large to fit on the stack!
    /// If that is the case, use the `new_boxed()` method.
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn new(leaves: &'a [[u8; OUT_LEN]; N]) -> Self {
        let mut tree = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); _];

        Self::init_internal(leaves, &mut tree);

        Self {
            leaves,
            // SAFETY: Statically guaranteed for all elements to be initialized
            tree: unsafe { tree.transpose().assume_init() },
        }
    }
    /// Like [`Self::new()`], but uses pre-allocated memory for instantiation
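    ///
    /// A usage sketch (illustrative, not a doc-test):
    ///
    /// ```ignore
    /// let mut instance = MaybeUninit::uninit();
    /// let tree = BalancedMerkleTree::<8>::new_in(&mut instance, &leaves);
    /// let root = tree.root();
    /// ```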
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn new_in<'b>(
        instance: &'b mut MaybeUninit<Self>,
        leaves: &'a [[u8; OUT_LEN]; N],
    ) -> &'b mut Self {
        let instance_ptr = instance.as_mut_ptr();
        // SAFETY: Valid and correctly aligned non-null pointer
        unsafe {
            (&raw mut (*instance_ptr).leaves).write(leaves);
        }
        let tree = {
            // SAFETY: Valid and correctly aligned non-null pointer
            let tree_ptr = unsafe { &raw mut (*instance_ptr).tree };
            // SAFETY: Allocated and correctly aligned uninitialized data
            unsafe {
                tree_ptr
                    .cast::<[MaybeUninit<[u8; OUT_LEN]>; N - 1]>()
                    .as_mut_unchecked()
            }
        };

        Self::init_internal(leaves, tree);

        // SAFETY: Initialized field by field above
        unsafe { instance.assume_init_mut() }
    }
    /// Like [`Self::new()`], but creates a heap-allocated instance, avoiding excessive stack
    /// usage for large trees
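    ///
    /// A usage sketch (illustrative, not a doc-test):
    ///
    /// ```ignore
    /// let leaves = [[0u8; 32]; 65536];
    /// // `new()` would need ~2 MiB of stack for the tree alone; `new_boxed()` puts it on the heap
    /// let tree = BalancedMerkleTree::new_boxed(&leaves);
    /// let root = tree.root();
    /// ```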
    #[cfg(feature = "alloc")]
    pub fn new_boxed(leaves: &'a [[u8; OUT_LEN]; N]) -> Box<Self> {
        let mut instance = Box::<Self>::new_uninit();

        Self::new_in(&mut instance, leaves);

        // SAFETY: Initialized by constructor above
        unsafe { instance.assume_init() }
    }

    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    fn init_internal(leaves: &[[u8; OUT_LEN]; N], tree: &mut [MaybeUninit<[u8; OUT_LEN]>; N - 1]) {
        let mut tree_hashes = tree.as_mut_slice();
        let mut level_hashes = leaves.as_slice();

        while level_hashes.len() > 1 {
            let num_pairs = level_hashes.len() / 2;
            let parent_hashes;
            // SAFETY: The size of the tree is statically known to match the number of leaves and
            // levels of hashes
            (parent_hashes, tree_hashes) = unsafe { tree_hashes.split_at_mut_unchecked(num_pairs) };

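            // Level sizes are powers of two, so this check is equivalent to
            // `parent_hashes.len() >= BATCH_HASH_NUM_BLOCKS`: full batches take the SIMD-friendly
            // path below, while narrow levels near the root fall back to pair-at-a-time hashing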
            if parent_hashes.len().is_multiple_of(BATCH_HASH_NUM_BLOCKS) {
                // SAFETY: Just checked to be a multiple of chunk size and not empty
                let parent_hashes_chunks =
                    unsafe { parent_hashes.as_chunks_unchecked_mut::<BATCH_HASH_NUM_BLOCKS>() };
                for (pairs, hashes) in level_hashes
                    .as_chunks::<BATCH_HASH_NUM_LEAVES>()
                    .0
                    .iter()
                    .zip(parent_hashes_chunks)
                {
                    // TODO: Would be nice to have a convenient method for this:
                    //  https://github.com/rust-lang/rust/issues/96097#issuecomment-3133515169
                    // SAFETY: Identical layout
                    let hashes = unsafe {
                        mem::transmute::<
                            &mut [MaybeUninit<[u8; OUT_LEN]>; BATCH_HASH_NUM_BLOCKS],
                            &mut MaybeUninit<[[u8; OUT_LEN]; BATCH_HASH_NUM_BLOCKS]>,
                        >(hashes)
                    };

                    // TODO: This memory copy is unfortunate, make hashing write into this memory
                    //  directly once blake3 API improves
                    hashes.write(hash_pairs(pairs));
                }
            } else {
                for (pair, parent_hash) in level_hashes
                    .as_chunks()
                    .0
                    .iter()
                    .zip(parent_hashes.iter_mut())
                {
                    // SAFETY: Same size and alignment
                    let pair = unsafe {
                        mem::transmute::<&[[u8; OUT_LEN]; BLOCK_LEN / OUT_LEN], &[u8; BLOCK_LEN]>(
                            pair,
                        )
                    };
                    parent_hash.write(hash_pair_block(pair));
                }
            }

            // SAFETY: Just initialized
            level_hashes = unsafe { parent_hashes.assume_init_ref() };
        }
    }

    // TODO: Method that generates not only the root, but also a proof, like Unbalanced Merkle Tree
    /// Compute the Merkle Tree root.
    ///
    /// This is functionally equivalent to creating an instance first and calling the
    /// [`Self::root()`] method, but is faster and avoids heap allocation when only the root is
    /// needed.
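    ///
    /// A sketch of the equivalence (illustrative, not a doc-test):
    ///
    /// ```ignore
    /// let root = BalancedMerkleTree::compute_root_only(&leaves);
    /// assert_eq!(root, BalancedMerkleTree::new(&leaves).root());
    /// ```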
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn compute_root_only(leaves: &[[u8; OUT_LEN]; N]) -> [u8; OUT_LEN]
    where
        [(); N.ilog2() as usize + 1]:,
        [(); compute_root_only_large_stack_size(N)]:,
    {
        // Special case for small trees below optimal SIMD width
        match N {
            2 => {
                let [root] = hash_pairs(leaves);

                return root;
            }
            4 => {
                let hashes = hash_pairs::<2, _>(leaves);
                let [root] = hash_pairs(&hashes);

                return root;
            }
            8 => {
                let hashes = hash_pairs::<4, _>(leaves);
                let hashes = hash_pairs::<2, _>(&hashes);
                let [root] = hash_pairs(&hashes);

                return root;
            }
            16 => {
                let hashes = hash_pairs::<8, _>(leaves);
                let hashes = hash_pairs::<4, _>(&hashes);
                let hashes = hash_pairs::<2, _>(&hashes);
                let [root] = hash_pairs(&hashes);

                return root;
            }
            _ => {
                // All smaller supported values of `N` were handled above, so `N` is at least
                // `BATCH_HASH_NUM_LEAVES` here and this assert always passes
                assert!(N >= BATCH_HASH_NUM_LEAVES);
            }
        }

        // Stack of intermediate nodes per tree level. The logic here is the same as with a small
        // tree above, except we store `BATCH_HASH_NUM_BLOCKS` hashes per level and do a
        // post-processing step at the very end to collapse them into a single root hash.
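        // The stack behaves like a binary counter: after `c` chunks have been processed, the set
        // bits of `c` mark the levels that currently hold a pending batch of hashes (e.g. after 6
        // chunks, 0b110, levels 1 and 2 are occupied and level 0 is free)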
        let mut stack =
            [[[0u8; OUT_LEN]; BATCH_HASH_NUM_BLOCKS]; compute_root_only_large_stack_size(N)];

        // This variable is reused for both the parent and the current halves to reduce stack
        // usage, instead of having a separate `current` variable
        let mut parent_current = [[0u8; OUT_LEN]; BATCH_HASH_NUM_LEAVES];
        for (num_chunks, chunk_leaves) in leaves
            .as_chunks::<BATCH_HASH_NUM_LEAVES>()
            .0
            .iter()
            .enumerate()
        {
            let (_parent_half, current_half) = parent_current.split_at_mut(BATCH_HASH_NUM_BLOCKS);

            let current = hash_pairs::<BATCH_HASH_NUM_BLOCKS, _>(chunk_leaves);
            current_half.copy_from_slice(&current);

            // Every bit set to `1` corresponds to an active Merkle Tree level
            let lowest_active_levels = num_chunks.trailing_ones() as usize;
            for parent in &mut stack[..lowest_active_levels] {
                let (parent_half, _current_half) =
                    parent_current.split_at_mut(BATCH_HASH_NUM_BLOCKS);
                parent_half.copy_from_slice(parent);

                let current = hash_pairs::<BATCH_HASH_NUM_BLOCKS, _>(&parent_current);

                let (_parent_half, current_half) =
                    parent_current.split_at_mut(BATCH_HASH_NUM_BLOCKS);
                current_half.copy_from_slice(&current);
            }

            let (_parent_half, current_half) = parent_current.split_at_mut(BATCH_HASH_NUM_BLOCKS);

            // Place the freshly computed `BATCH_HASH_NUM_BLOCKS` hashes into the first inactive
            // level
            stack[lowest_active_levels].copy_from_slice(current_half);
        }

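        // Collapse the top level of the stack (the only one still pending at this point) into the
        // root: 16 -> 8 -> 4 -> 2 -> 1 hashes with `BATCH_HASH_NUM_BLOCKS == 16`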
        let hashes = &mut stack[compute_root_only_large_stack_size(N) - 1];
        let hashes = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 2 }, _>(hashes);
        let hashes = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 4 }, _>(&hashes);
        let hashes = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 8 }, _>(&hashes);
        let [root] = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 16 }, _>(&hashes);

        root
    }

    /// Get the root of the Merkle Tree
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn root(&self) -> [u8; OUT_LEN] {
        *self
            .tree
            .last()
            .or(self.leaves.last())
            .expect("There is always at least one leaf hash; qed")
    }

    /// Iterator over proofs, in the same order as the provided leaf hashes
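    ///
    /// A usage sketch (illustrative, not a doc-test):
    ///
    /// ```ignore
    /// for (leaf_index, proof) in tree.all_proofs().enumerate() {
    ///     assert!(BalancedMerkleTree::<N>::verify(&root, &proof, leaf_index, leaves[leaf_index]));
    /// }
    /// ```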
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn all_proofs(
        &self,
    ) -> impl ExactSizeIterator<Item = [[u8; OUT_LEN]; N.ilog2() as usize]> + TrustedLen
    where
        [(); N.ilog2() as usize]:,
    {
        let iter = self.leaves.as_chunks().0.iter().enumerate().flat_map(
            |(pair_index, &[left_hash, right_hash])| {
                let mut left_proof = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); N.ilog2() as usize];
                left_proof[0].write(right_hash);

                let left_proof = {
                    let (_, shared_proof) = left_proof.split_at_mut(1);

                    let mut tree_hashes = self.tree.as_slice();
                    let mut parent_position = pair_index;
                    let mut parent_level_size = N / 2;

                    for hash in shared_proof {
                        // The line below is a more efficient branchless version of this:
                        // let parent_other_position = if parent_position % 2 == 0 {
                        //     parent_position + 1
                        // } else {
                        //     parent_position - 1
                        // };
                        let parent_other_position = parent_position ^ 1;

                        // SAFETY: Statically guaranteed to be present by constructor
                        let other_hash =
                            unsafe { tree_hashes.get_unchecked(parent_other_position) };
                        hash.write(*other_hash);
                        (_, tree_hashes) = tree_hashes.split_at(parent_level_size);

                        parent_position /= 2;
                        parent_level_size /= 2;
                    }

                    // SAFETY: Just initialized
                    unsafe { left_proof.transpose().assume_init() }
                };

                let mut right_proof = left_proof;
                right_proof[0] = left_hash;

                [left_proof, right_proof]
            },
        );

        ProofsIterator { iter, len: N }
    }

    /// Verify a previously generated proof
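    ///
    /// A usage sketch (illustrative, not a doc-test):
    ///
    /// ```ignore
    /// assert!(BalancedMerkleTree::<N>::verify(&root, &proof, leaf_index, leaf));
    /// // A tampered leaf (or a wrong index) fails verification
    /// assert!(!BalancedMerkleTree::<N>::verify(&root, &proof, leaf_index, [0xff; 32]));
    /// ```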
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn verify(
        root: &[u8; OUT_LEN],
        proof: &[[u8; OUT_LEN]; N.ilog2() as usize],
        leaf_index: usize,
        leaf: [u8; OUT_LEN],
    ) -> bool
    where
        [(); N.ilog2() as usize]:,
    {
        if leaf_index >= N {
            return false;
        }

        let mut computed_root = leaf;

        let mut position = leaf_index;
        for hash in proof {
            computed_root = if position.is_multiple_of(2) {
                hash_pair(&computed_root, hash)
            } else {
                hash_pair(hash, &computed_root)
            };

            position /= 2;
        }

        root == &computed_root
    }
}

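// `flat_map()` does not implement `ExactSizeIterator`/`TrustedLen`, so the exact number of proofs
// (`N`) is tracked manually to provide both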
struct ProofsIterator<Iter> {
    iter: Iter,
    len: usize,
}

impl<Iter> Iterator for ProofsIterator<Iter>
where
    Iter: Iterator,
{
    type Item = Iter::Item;

    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.iter.next();
        self.len = self.len.saturating_sub(1);
        item
    }

    #[inline(always)]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.len, Some(self.len))
    }

    #[inline(always)]
    fn count(self) -> usize
    where
        Self: Sized,
    {
        self.len
    }

    #[inline(always)]
    fn last(self) -> Option<Self::Item>
    where
        Self: Sized,
    {
        self.iter.last()
    }

    #[inline(always)]
    fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
        self.len = self.len.saturating_sub(n);
        self.iter.advance_by(n)
    }

    #[inline(always)]
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        self.len = self.len.saturating_sub(n.saturating_add(1));
        self.iter.nth(n)
    }
}

impl<Iter> ExactSizeIterator for ProofsIterator<Iter>
where
    Iter: Iterator,
{
    #[inline(always)]
    fn len(&self) -> usize {
        self.len
    }
}

unsafe impl<Iter> TrustedLen for ProofsIterator<Iter> where Iter: Iterator {}