// ab_merkle_tree/balanced.rs

use crate::{hash_pair, hash_pair_block, hash_pairs};
use ab_blake3::{BLOCK_LEN, OUT_LEN};
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
use core::iter::TrustedLen;
use core::mem;
use core::mem::MaybeUninit;
use core::num::NonZero;

10/// Optimal number of blocks for hashing at once to saturate BLAKE3 SIMD on any hardware
11const BATCH_HASH_NUM_BLOCKS: usize = 16;
12/// Number of leaves that corresponds to [`BATCH_HASH_NUM_BLOCKS`]
13const BATCH_HASH_NUM_LEAVES: usize = BATCH_HASH_NUM_BLOCKS * BLOCK_LEN / OUT_LEN;
14
15/// Inner function used in [`BalancedMerkleTree::compute_root_only()`] for stack allocation, only
16/// public due to use in generic bounds
17pub const fn compute_root_only_large_stack_size(n: usize) -> usize {
18    // For small trees the large stack is not used, so the returned value does not matter as long as
19    // it compiles
20    if n < BATCH_HASH_NUM_LEAVES {
21        return 1;
22    }
23
24    (n / BATCH_HASH_NUM_LEAVES).ilog2() as usize + 1
25}
26
/// Ensuring only supported `N` can be specified for [`BalancedMerkleTree`].
///
/// This is essentially a workaround for the current Rust type system constraints that do not allow
/// a nicer way to do the same thing at compile time.
///
/// The returned value is always `0`; the function exists purely for its assertions, which are
/// evaluated when it is used as an array length in a `where` bound.
///
/// # Panics
///
/// Panics if `n` is not a power of two or if `n` is not greater than one.
pub const fn ensure_supported_n(n: usize) -> usize {
    assert!(
        n.is_power_of_two(),
        "Balanced Merkle Tree must have a number of leaves that is a power of 2"
    );

    assert!(
        n > 1,
        "This Balanced Merkle Tree must have more than one leaf"
    );

    0
}

45/// Merkle Tree variant that has hash-sized leaves and is fully balanced according to configured
46/// generic parameter.
47///
48/// This can be considered a general case of [`UnbalancedMerkleTree`]. The root and proofs are
49/// identical for both in case the number of leaves is a power of two. For the number of leaves that
50/// is a power of two [`UnbalancedMerkleTree`] is useful when a single proof needs to be generated
51/// and the number of leaves is very large (it can generate proofs with very little RAM usage
52/// compared to this version).
53///
54/// [`UnbalancedMerkleTree`]: crate::unbalanced::UnbalancedMerkleTree
55///
56/// This Merkle Tree implementation is best suited for use cases when proofs for all (or most) of
57/// the elements need to be generated and the whole tree easily fits into memory. It can also be
58/// constructed and proofs can be generated efficiently without heap allocations.
59///
60/// With all parameters of the tree known statically, it results in the most efficient version of
61/// the code being generated for a given set of parameters.
62#[derive(Debug)]
63pub struct BalancedMerkleTree<'a, const N: usize>
64where
65    [(); N - 1]:,
66{
67    leaves: &'a [[u8; OUT_LEN]],
68    // This tree doesn't include leaves because we have them in `leaves` field
69    tree: [[u8; OUT_LEN]; N - 1],
70}
71
72// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
73//  https://github.com/BLAKE3-team/BLAKE3/issues/478
74impl<'a, const N: usize> BalancedMerkleTree<'a, N>
75where
76    [(); N - 1]:,
77    [(); ensure_supported_n(N)]:,
78{
79    /// Create a new tree from a fixed set of elements.
80    ///
81    /// The data structure is statically allocated and might be too large to fit on the stack!
82    /// If that is the case, use `new_boxed()` method.
83    // TODO: Unlock on RISC-V, it started failing since https://github.com/nazar-pc/abundance/pull/551
84    //  for unknown reason
85    #[cfg_attr(
86        all(feature = "no-panic", not(target_arch = "riscv64")),
87        no_panic::no_panic
88    )]
89    pub fn new(leaves: &'a [[u8; OUT_LEN]; N]) -> Self {
90        let mut tree = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); _];
91
92        Self::init_internal(leaves, &mut tree);
93
94        Self {
95            leaves,
96            // SAFETY: Statically guaranteed for all elements to be initialized
97            tree: unsafe { tree.transpose().assume_init() },
98        }
99    }
100
101    /// Like [`Self::new()`], but used pre-allocated memory for instantiation
102    // TODO: Unlock on RISC-V, it started failing since https://github.com/nazar-pc/abundance/pull/551
103    //  for unknown reason
104    #[cfg_attr(
105        all(feature = "no-panic", not(target_arch = "riscv64")),
106        no_panic::no_panic
107    )]
108    pub fn new_in<'b>(
109        instance: &'b mut MaybeUninit<Self>,
110        leaves: &'a [[u8; OUT_LEN]; N],
111    ) -> &'b mut Self {
112        let instance_ptr = instance.as_mut_ptr();
113        // SAFETY: Valid and correctly aligned non-null pointer
114        unsafe {
115            (&raw mut (*instance_ptr).leaves).write(leaves);
116        }
117        let tree = {
118            // SAFETY: Valid and correctly aligned non-null pointer
119            let tree_ptr = unsafe { &raw mut (*instance_ptr).tree };
120            // SAFETY: Allocated and correctly aligned uninitialized data
121            unsafe {
122                tree_ptr
123                    .cast::<[MaybeUninit<[u8; OUT_LEN]>; N - 1]>()
124                    .as_mut_unchecked()
125            }
126        };
127
128        Self::init_internal(leaves, tree);
129
130        // SAFETY: Initialized field by field above
131        unsafe { instance.assume_init_mut() }
132    }
133
134    /// Like [`Self::new()`], but creates heap-allocated instance, avoiding excessive stack usage
135    /// for large trees
136    #[cfg(feature = "alloc")]
137    pub fn new_boxed(leaves: &'a [[u8; OUT_LEN]; N]) -> Box<Self> {
138        let mut instance = Box::<Self>::new_uninit();
139
140        Self::new_in(&mut instance, leaves);
141
142        // SAFETY: Initialized by constructor above
143        unsafe { instance.assume_init() }
144    }
145
146    // TODO: Unlock on RISC-V, it started failing since https://github.com/nazar-pc/abundance/pull/551
147    //  for unknown reason
148    #[cfg_attr(
149        all(feature = "no-panic", not(target_arch = "riscv64")),
150        no_panic::no_panic
151    )]
152    fn init_internal(leaves: &[[u8; OUT_LEN]; N], tree: &mut [MaybeUninit<[u8; OUT_LEN]>; N - 1]) {
153        let mut tree_hashes = tree.as_mut_slice();
154        let mut level_hashes = leaves.as_slice();
155
156        while level_hashes.len() > 1 {
157            let num_pairs = level_hashes.len() / 2;
158            let parent_hashes;
159            // SAFETY: The size of the tree is statically known to match the number of leaves and
160            // levels of hashes
161            (parent_hashes, tree_hashes) = unsafe { tree_hashes.split_at_mut_unchecked(num_pairs) };
162
163            if parent_hashes.len().is_multiple_of(BATCH_HASH_NUM_BLOCKS) {
164                // SAFETY: Just checked to be a multiple of chunk size and not empty
165                let parent_hashes_chunks =
166                    unsafe { parent_hashes.as_chunks_unchecked_mut::<BATCH_HASH_NUM_BLOCKS>() };
167                for (pairs, hashes) in level_hashes
168                    .as_chunks::<BATCH_HASH_NUM_LEAVES>()
169                    .0
170                    .iter()
171                    .zip(parent_hashes_chunks)
172                {
173                    // TODO: Would be nice to have a convenient method for this:
174                    //  https://github.com/rust-lang/rust/pull/145504#pullrequestreview-3788155275
175                    // SAFETY: Identical layout
176                    let hashes = unsafe {
177                        mem::transmute::<
178                            &mut [MaybeUninit<[u8; OUT_LEN]>; BATCH_HASH_NUM_BLOCKS],
179                            &mut MaybeUninit<[[u8; OUT_LEN]; BATCH_HASH_NUM_BLOCKS]>,
180                        >(hashes)
181                    };
182
183                    // TODO: This memory copy is unfortunate, make hashing write into this memory
184                    //  directly once blake3 API improves
185                    hashes.write(hash_pairs(pairs));
186                }
187            } else {
188                for (pair, parent_hash) in level_hashes
189                    .as_chunks()
190                    .0
191                    .iter()
192                    .zip(parent_hashes.iter_mut())
193                {
194                    // SAFETY: Same size and alignment
195                    let pair = unsafe {
196                        mem::transmute::<&[[u8; OUT_LEN]; BLOCK_LEN / OUT_LEN], &[u8; BLOCK_LEN]>(
197                            pair,
198                        )
199                    };
200                    parent_hash.write(hash_pair_block(pair));
201                }
202            }
203
204            // SAFETY: Just initialized
205            level_hashes = unsafe { parent_hashes.assume_init_ref() };
206        }
207    }
208
209    // TODO: Method that generates not only root, but also proof, like Unbalanced Merkle Tree
210    /// Compute Merkle Tree root.
211    ///
212    /// This is functionally equivalent to creating an instance first and calling [`Self::root()`]
213    /// method, but is faster and avoids heap allocation when root is the only thing that is needed.
214    #[inline]
215    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
216    pub fn compute_root_only(leaves: &[[u8; OUT_LEN]; N]) -> [u8; OUT_LEN]
217    where
218        [(); N.ilog2() as usize + 1]:,
219        [(); compute_root_only_large_stack_size(N)]:,
220    {
221        // Special case for small trees below optimal SIMD width
222        match N {
223            2 => {
224                let [root] = hash_pairs(leaves);
225
226                return root;
227            }
228            4 => {
229                let hashes = hash_pairs::<2, _>(leaves);
230                let [root] = hash_pairs(&hashes);
231
232                return root;
233            }
234            8 => {
235                let hashes = hash_pairs::<4, _>(leaves);
236                let hashes = hash_pairs::<2, _>(&hashes);
237                let [root] = hash_pairs(&hashes);
238
239                return root;
240            }
241            16 => {
242                let hashes = hash_pairs::<8, _>(leaves);
243                let hashes = hash_pairs::<4, _>(&hashes);
244                let hashes = hash_pairs::<2, _>(&hashes);
245                let [root] = hash_pairs(&hashes);
246
247                return root;
248            }
249            _ => {
250                // We know this is the case
251                assert!(N >= BATCH_HASH_NUM_LEAVES);
252            }
253        }
254
255        // Stack of intermediate nodes per tree level. The logic here is the same as with a small
256        // tree above, except we store `BATCH_HASH_NUM_BLOCKS` hashes per level and do a
257        // post-processing step at the very end to collapse them into a single root hash.
258        let mut stack =
259            [[[0u8; OUT_LEN]; BATCH_HASH_NUM_BLOCKS]; compute_root_only_large_stack_size(N)];
260
261        // This variable allows reusing and reducing stack usage instead of having a separate
262        // `current` variable
263        let mut parent_current = [[0u8; OUT_LEN]; BATCH_HASH_NUM_LEAVES];
264        for (num_chunks, chunk_leaves) in leaves
265            .as_chunks::<BATCH_HASH_NUM_LEAVES>()
266            .0
267            .iter()
268            .enumerate()
269        {
270            let current_half = &mut parent_current[BATCH_HASH_NUM_BLOCKS..];
271
272            let current = hash_pairs::<BATCH_HASH_NUM_BLOCKS, _>(chunk_leaves);
273            current_half.copy_from_slice(&current);
274
275            // Every bit set to `1` corresponds to an active Merkle Tree level
276            let lowest_active_levels = num_chunks.trailing_ones() as usize;
277            for parent in &mut stack[..lowest_active_levels] {
278                let parent_half = &mut parent_current[..BATCH_HASH_NUM_BLOCKS];
279                parent_half.copy_from_slice(parent);
280
281                let current = hash_pairs::<BATCH_HASH_NUM_BLOCKS, _>(&parent_current);
282
283                let current_half = &mut parent_current[BATCH_HASH_NUM_BLOCKS..];
284                current_half.copy_from_slice(&current);
285            }
286
287            let current_half = &mut parent_current[BATCH_HASH_NUM_BLOCKS..];
288
289            // Place freshly computed 8 hashes into the first inactive level
290            stack[lowest_active_levels].copy_from_slice(current_half);
291        }
292
293        let hashes = &mut stack[compute_root_only_large_stack_size(N) - 1];
294        let hashes = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 2 }, _>(hashes);
295        let hashes = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 4 }, _>(&hashes);
296        let hashes = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 8 }, _>(&hashes);
297        let [root] = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 16 }, _>(&hashes);
298
299        root
300    }
301
302    /// Get the root of Merkle Tree
303    #[inline]
304    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
305    pub fn root(&self) -> [u8; OUT_LEN] {
306        *self
307            .tree
308            .last()
309            .or(self.leaves.last())
310            .expect("There is always at least one leaf hash; qed")
311    }
312
313    /// Iterator over proofs in the same order as provided leaf hashes
314    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
315    pub fn all_proofs(&self) -> ProofsIterator<'_, N>
316    where
317        [(); N.ilog2() as usize]:,
318    {
319        ProofsIterator {
320            leaves: self.leaves,
321            tree: &self.tree,
322            leaf_index: 0,
323            len: N,
324        }
325    }
326
327    /// Verify previously generated proof
328    #[inline]
329    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
330    pub fn verify(
331        root: &[u8; OUT_LEN],
332        proof: &[[u8; OUT_LEN]; N.ilog2() as usize],
333        leaf_index: usize,
334        leaf: [u8; OUT_LEN],
335    ) -> bool
336    where
337        [(); N.ilog2() as usize]:,
338    {
339        if leaf_index >= N {
340            return false;
341        }
342
343        let mut computed_root = leaf;
344
345        let mut position = leaf_index;
346        for hash in proof {
347            computed_root = if position.is_multiple_of(2) {
348                hash_pair(&computed_root, hash)
349            } else {
350                hash_pair(hash, &computed_root)
351            };
352
353            position /= 2;
354        }
355
356        root == &computed_root
357    }
358}
359
360/// Iterator over proofs for a balanced Merkle tree
361#[derive(Debug)]
362pub struct ProofsIterator<'a, const N: usize>
363where
364    [(); N.ilog2() as usize]:,
365    [(); N - 1]:,
366    [(); ensure_supported_n(N)]:,
367{
368    pub(super) leaves: &'a [[u8; OUT_LEN]],
369    pub(super) tree: &'a [[u8; OUT_LEN]; N - 1],
370    pub(super) leaf_index: usize,
371    pub(super) len: usize,
372}
373
374impl<'a, const N: usize> Iterator for ProofsIterator<'a, N>
375where
376    [(); N.ilog2() as usize]:,
377    [(); N - 1]:,
378    [(); ensure_supported_n(N)]:,
379{
380    type Item = [[u8; OUT_LEN]; N.ilog2() as usize];
381
382    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
383    fn next(&mut self) -> Option<Self::Item> {
384        if self.len == 0 {
385            return None;
386        }
387        self.len -= 1;
388
389        let index = self.leaf_index;
390        self.leaf_index += 1;
391
392        // The line below is a more efficient branchless version of this:
393        // let sibling_index = if index % 2 == 0 {
394        //     index + 1
395        // } else {
396        //     index - 1
397        // };
398        let sibling_index = index ^ 1;
399        // SAFETY: `index < N` guaranteed by `len` tracking
400        let sibling_hash = *unsafe { self.leaves.get_unchecked(sibling_index) };
401
402        let mut proof = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); _];
403        proof[0].write(sibling_hash);
404
405        // Part that is shared between left and right leaf proofs
406        let shared_proof = &mut proof[1..];
407
408        let mut tree_hashes = self.tree.as_slice();
409        let mut parent_position = index / 2;
410        let mut parent_level_size = N / 2;
411
412        for hash in shared_proof {
413            let parent_other_position = parent_position ^ 1;
414
415            // SAFETY: Statically guaranteed to be present by constructor
416            let other_hash = unsafe { tree_hashes.get_unchecked(parent_other_position) };
417            hash.write(*other_hash);
418            tree_hashes = &tree_hashes[parent_level_size..];
419
420            parent_position /= 2;
421            parent_level_size /= 2;
422        }
423
424        // SAFETY: Just initialized
425        Some(unsafe { proof.transpose().assume_init() })
426    }
427
428    #[inline(always)]
429    fn size_hint(&self) -> (usize, Option<usize>) {
430        (self.len, Some(self.len))
431    }
432
433    #[inline(always)]
434    fn count(self) -> usize {
435        self.len
436    }
437
438    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
439    fn last(mut self) -> Option<Self::Item> {
440        if self.len == 0 {
441            return None;
442        }
443        self.leaf_index = N - 1;
444        self.len = 1;
445        self.next()
446    }
447
448    #[inline(always)]
449    fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
450        let advance = n.min(self.len);
451        self.leaf_index += advance;
452        self.len -= advance;
453        NonZero::new(n - advance).map_or(Ok(()), Err)
454    }
455
456    #[inline(always)]
457    fn nth(&mut self, n: usize) -> Option<Self::Item> {
458        match self.advance_by(n) {
459            Ok(()) => self.next(),
460            Err(_) => None,
461        }
462    }
463}
464
465impl<'a, const N: usize> ExactSizeIterator for ProofsIterator<'a, N>
466where
467    [(); N.ilog2() as usize]:,
468    [(); N - 1]:,
469    [(); ensure_supported_n(N)]:,
470{
471    #[inline(always)]
472    fn len(&self) -> usize {
473        self.len
474    }
475}
476
477// SAFETY: size_hint is always exact
478unsafe impl<'a, const N: usize> TrustedLen for ProofsIterator<'a, N>
479where
480    [(); N.ilog2() as usize]:,
481    [(); N - 1]:,
482    [(); ensure_supported_n(N)]:,
483{
484}