ab_merkle_tree/
balanced_hashed.rs

1use crate::hash_pair;
2#[cfg(feature = "alloc")]
3use alloc::boxed::Box;
4use blake3::OUT_LEN;
5use core::iter;
6use core::iter::TrustedLen;
7use core::mem::MaybeUninit;
8
9/// Merkle Tree variant that has hash-sized leaves and is fully balanced according to configured
10/// generic parameter.
11///
12/// This can be considered a general case of [`UnbalancedHashedMerkleTree`]. The root and proofs are
13/// identical for both in case the number of leaves is a power of two. For the number of leaves that
14/// is a power of two [`UnbalancedHashedMerkleTree`] is useful when a single proof needs to be
15/// generated and the number of leaves is very large (it can generate proofs with very little RAM
16/// usage compared to this version).
17///
18/// [`UnbalancedHashedMerkleTree`]: crate::unbalanced_hashed::UnbalancedHashedMerkleTree
19///
20/// This Merkle Tree implementation is best suited for use cases when proofs for all (or most) of
21/// the elements need to be generated and the whole tree easily fits into memory. It can also be
22/// constructed and proofs can be generated efficiently without heap allocations.
23///
24/// With all parameters of the tree known statically, it results in the most efficient version of
25/// the code being generated for a given set of parameters.
26#[derive(Debug)]
27pub struct BalancedHashedMerkleTree<'a, const N: usize>
28where
29    [(); N - 1]:,
30{
31    leaves: &'a [[u8; OUT_LEN]],
32    // This tree doesn't include leaves because we know the size
33    tree: [[u8; OUT_LEN]; N - 1],
34}
35
36// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
37//  https://github.com/BLAKE3-team/BLAKE3/issues/478
38impl<'a, const N: usize> BalancedHashedMerkleTree<'a, N>
39where
40    [(); N - 1]:,
41{
42    /// Create a new tree from a fixed set of elements.
43    ///
44    /// The data structure is statically allocated and might be too large to fit on the stack!
45    /// If that is the case, use `new_boxed()` method.
46    pub fn new(leaves: &'a [[u8; OUT_LEN]; N]) -> Self {
47        let mut tree = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); _];
48
49        Self::init_internal(leaves, &mut tree);
50
51        Self {
52            leaves,
53            // SAFETY: Statically guaranteed for all elements to be initialized
54            tree: unsafe { tree.transpose().assume_init() },
55        }
56    }
57
58    /// Like [`Self::new()`], but used pre-allocated memory for instantiation
59    pub fn new_in<'b>(
60        instance: &'b mut MaybeUninit<Self>,
61        leaves: &'a [[u8; OUT_LEN]; N],
62    ) -> &'b mut Self {
63        let instance_ptr = instance.as_mut_ptr();
64        // SAFETY: Valid and correctly aligned non-null pointer
65        unsafe {
66            (&raw mut (*instance_ptr).leaves).write(leaves);
67        }
68        let tree = {
69            // SAFETY: Valid and correctly aligned non-null pointer
70            let tree_ptr = unsafe { &raw mut (*instance_ptr).tree };
71            // SAFETY: Allocated and correctly aligned uninitialized data
72            unsafe {
73                tree_ptr
74                    .cast::<[MaybeUninit<[u8; OUT_LEN]>; N - 1]>()
75                    .as_mut_unchecked()
76            }
77        };
78
79        Self::init_internal(leaves, tree);
80
81        // SAFETY: Initialized field by field above
82        unsafe { instance.assume_init_mut() }
83    }
84
85    /// Like [`Self::new()`], but creates heap-allocated instance, avoiding excessive stack usage
86    /// for large trees
87    #[cfg(feature = "alloc")]
88    pub fn new_boxed(leaves: &'a [[u8; OUT_LEN]; N]) -> Box<Self> {
89        let mut instance = Box::<Self>::new_uninit();
90
91        Self::new_in(&mut instance, leaves);
92
93        // SAFETY: Initialized by constructor above
94        unsafe { instance.assume_init() }
95    }
96
97    fn init_internal(leaves: &[[u8; OUT_LEN]; N], tree: &mut [MaybeUninit<[u8; OUT_LEN]>; N - 1]) {
98        let mut tree_hashes = tree.as_mut_slice();
99        let mut level_hashes = leaves.as_slice();
100
101        let mut pair = [0u8; OUT_LEN * 2];
102        while level_hashes.len() > 1 {
103            let num_pairs = level_hashes.len() / 2;
104            let parent_hashes;
105            // SAFETY: The size of the tree is statically known to match the number of leaves and
106            // levels of hashes
107            (parent_hashes, tree_hashes) = unsafe { tree_hashes.split_at_mut_unchecked(num_pairs) };
108
109            for pair_index in 0..num_pairs {
110                // SAFETY: Entry is statically known to be present
111                let left_hash = unsafe { level_hashes.get_unchecked(pair_index * 2) };
112                // SAFETY: Entry is statically known to be present
113                let right_hash = unsafe { level_hashes.get_unchecked(pair_index * 2 + 1) };
114                // SAFETY: Entry is statically known to be present
115                let parent_hash = unsafe { parent_hashes.get_unchecked_mut(pair_index) };
116
117                pair[..OUT_LEN].copy_from_slice(left_hash);
118                pair[OUT_LEN..].copy_from_slice(right_hash);
119
120                parent_hash.write(hash_pair(left_hash, right_hash));
121            }
122
123            // SAFETY: Just initialized
124            level_hashes = unsafe { parent_hashes.assume_init_ref() };
125        }
126    }
127
128    /// Compute Merkle Tree Root.
129    ///
130    /// This is functionally equivalent to creating an instance first and calling [`Self::root()`]
131    /// method, but is faster and avoids heap allocation when root is the only thing that is needed.
132    #[inline]
133    pub fn compute_root_only(leaves: &[[u8; OUT_LEN]; N]) -> [u8; OUT_LEN]
134    where
135        [(); N.ilog2() as usize + 1]:,
136    {
137        if leaves.len() == 1 {
138            return leaves[0];
139        }
140
141        // Stack of intermediate nodes per tree level
142        let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
143        // Bitmask: bit `i = 1` if level `i` is active
144        let mut active_levels = 0_u32;
145
146        for &hash in leaves {
147            let mut current = hash;
148            let mut level = 0;
149
150            // Check if level is active by testing bit (active_levels & (1 << level))
151            while active_levels & (1 << level) != 0 {
152                current = hash_pair(&stack[level], &current);
153
154                // Clear the current level
155                active_levels &= !(1 << level);
156                level += 1;
157            }
158
159            // Place the current hash at the first inactive level
160            stack[level] = current;
161            // Set bit for level
162            active_levels |= 1 << level;
163        }
164
165        stack[N.ilog2() as usize]
166    }
167
168    /// Get the root of Merkle Tree.
169    ///
170    /// In case a tree contains a single leaf hash, that leaf hash is returned.
171    #[inline]
172    pub fn root(&self) -> [u8; OUT_LEN] {
173        *self
174            .tree
175            .last()
176            .or(self.leaves.last())
177            .expect("There is always at least one leaf hash; qed")
178    }
179
180    /// Iterator over proofs in the same order as provided leaf hashes
181    pub fn all_proofs(
182        &self,
183    ) -> impl ExactSizeIterator<Item = [[u8; OUT_LEN]; N.ilog2() as usize]> + TrustedLen
184    where
185        [(); N.ilog2() as usize]:,
186    {
187        let iter = self
188            .leaves
189            .array_chunks()
190            .enumerate()
191            .flat_map(|(pair_index, &[left_hash, right_hash])| {
192                let mut left_proof = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); N.ilog2() as usize];
193                left_proof[0].write(right_hash);
194
195                let left_proof = {
196                    let (_, shared_proof) = left_proof.split_at_mut(1);
197
198                    let mut tree_hashes = self.tree.as_slice();
199                    let mut parent_position = pair_index;
200                    let mut parent_level_size = N / 2;
201
202                    for hash in shared_proof {
203                        let parent_other_position = if parent_position % 2 == 0 {
204                            parent_position + 1
205                        } else {
206                            parent_position - 1
207                        };
208                        // SAFETY: Statically guaranteed to be present by constructor
209                        let other_hash =
210                            unsafe { tree_hashes.get_unchecked(parent_other_position) };
211                        hash.write(*other_hash);
212                        (_, tree_hashes) = tree_hashes.split_at(parent_level_size);
213
214                        parent_position /= 2;
215                        parent_level_size /= 2;
216                    }
217
218                    // SAFETY: Just initialized
219                    unsafe { left_proof.transpose().assume_init() }
220                };
221
222                let mut right_proof = left_proof;
223                right_proof[0] = left_hash;
224
225                [left_proof, right_proof]
226            })
227            // Special case for a single leaf tree to make sure proof is returned, even if it is
228            // empty
229            .chain({
230                let mut returned = false;
231
232                iter::from_fn(move || {
233                    if N == 1 && !returned {
234                        returned = true;
235                        Some([[0; OUT_LEN]; N.ilog2() as usize])
236                    } else {
237                        None
238                    }
239                })
240            });
241
242        ProofsIterator { iter, len: N }
243    }
244
245    /// Verify previously generated proof
246    #[inline]
247    pub fn verify(
248        root: &[u8; OUT_LEN],
249        proof: &[[u8; OUT_LEN]; N.ilog2() as usize],
250        leaf_index: usize,
251        leaf: [u8; OUT_LEN],
252    ) -> bool
253    where
254        [(); N.ilog2() as usize]:,
255    {
256        if leaf_index >= N {
257            return false;
258        }
259
260        let mut computed_root = leaf;
261
262        let mut position = leaf_index;
263        for hash in proof {
264            computed_root = if position % 2 == 0 {
265                hash_pair(&computed_root, hash)
266            } else {
267                hash_pair(hash, &computed_root)
268            };
269
270            position /= 2;
271        }
272
273        root == &computed_root
274    }
275}
276
277struct ProofsIterator<Iter> {
278    iter: Iter,
279    len: usize,
280}
281
282impl<Iter> Iterator for ProofsIterator<Iter>
283where
284    Iter: Iterator,
285{
286    type Item = Iter::Item;
287
288    #[inline(always)]
289    fn next(&mut self) -> Option<Self::Item> {
290        let item = self.iter.next();
291        self.len = self.len.saturating_sub(1);
292        item
293    }
294
295    #[inline(always)]
296    fn size_hint(&self) -> (usize, Option<usize>) {
297        (self.len, Some(self.len))
298    }
299
300    #[inline(always)]
301    fn count(self) -> usize
302    where
303        Self: Sized,
304    {
305        self.len
306    }
307}
308
309impl<Iter> ExactSizeIterator for ProofsIterator<Iter>
310where
311    Iter: Iterator,
312{
313    #[inline(always)]
314    fn len(&self) -> usize {
315        self.len
316    }
317}
318
319unsafe impl<Iter> TrustedLen for ProofsIterator<Iter> where Iter: Iterator {}