ab_merkle_tree/
balanced.rs

1use crate::hash_pair;
2use ab_blake3::OUT_LEN;
3#[cfg(feature = "alloc")]
4use alloc::boxed::Box;
5use core::iter;
6use core::iter::TrustedLen;
7use core::mem::MaybeUninit;
8
9/// Merkle Tree variant that has hash-sized leaves and is fully balanced according to configured
10/// generic parameter.
11///
12/// This can be considered a general case of [`UnbalancedMerkleTree`]. The root and proofs are
13/// identical for both in case the number of leaves is a power of two. For the number of leaves that
14/// is a power of two [`UnbalancedMerkleTree`] is useful when a single proof needs to be generated
15/// and the number of leaves is very large (it can generate proofs with very little RAM usage
16/// compared to this version).
17///
18/// [`UnbalancedMerkleTree`]: crate::unbalanced::UnbalancedMerkleTree
19///
20/// This Merkle Tree implementation is best suited for use cases when proofs for all (or most) of
21/// the elements need to be generated and the whole tree easily fits into memory. It can also be
22/// constructed and proofs can be generated efficiently without heap allocations.
23///
24/// With all parameters of the tree known statically, it results in the most efficient version of
25/// the code being generated for a given set of parameters.
26#[derive(Debug)]
27pub struct BalancedMerkleTree<'a, const N: usize>
28where
29    [(); N - 1]:,
30{
31    leaves: &'a [[u8; OUT_LEN]],
32    // This tree doesn't include leaves because we know the size
33    tree: [[u8; OUT_LEN]; N - 1],
34}
35
36// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
37//  https://github.com/BLAKE3-team/BLAKE3/issues/478
38impl<'a, const N: usize> BalancedMerkleTree<'a, N>
39where
40    [(); N - 1]:,
41{
42    /// Create a new tree from a fixed set of elements.
43    ///
44    /// The data structure is statically allocated and might be too large to fit on the stack!
45    /// If that is the case, use `new_boxed()` method.
46    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
47    pub fn new(leaves: &'a [[u8; OUT_LEN]; N]) -> Self {
48        let mut tree = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); _];
49
50        Self::init_internal(leaves, &mut tree);
51
52        Self {
53            leaves,
54            // SAFETY: Statically guaranteed for all elements to be initialized
55            tree: unsafe { tree.transpose().assume_init() },
56        }
57    }
58
59    /// Like [`Self::new()`], but used pre-allocated memory for instantiation
60    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
61    pub fn new_in<'b>(
62        instance: &'b mut MaybeUninit<Self>,
63        leaves: &'a [[u8; OUT_LEN]; N],
64    ) -> &'b mut Self {
65        let instance_ptr = instance.as_mut_ptr();
66        // SAFETY: Valid and correctly aligned non-null pointer
67        unsafe {
68            (&raw mut (*instance_ptr).leaves).write(leaves);
69        }
70        let tree = {
71            // SAFETY: Valid and correctly aligned non-null pointer
72            let tree_ptr = unsafe { &raw mut (*instance_ptr).tree };
73            // SAFETY: Allocated and correctly aligned uninitialized data
74            unsafe {
75                tree_ptr
76                    .cast::<[MaybeUninit<[u8; OUT_LEN]>; N - 1]>()
77                    .as_mut_unchecked()
78            }
79        };
80
81        Self::init_internal(leaves, tree);
82
83        // SAFETY: Initialized field by field above
84        unsafe { instance.assume_init_mut() }
85    }
86
87    /// Like [`Self::new()`], but creates heap-allocated instance, avoiding excessive stack usage
88    /// for large trees
89    #[cfg(feature = "alloc")]
90    pub fn new_boxed(leaves: &'a [[u8; OUT_LEN]; N]) -> Box<Self> {
91        let mut instance = Box::<Self>::new_uninit();
92
93        Self::new_in(&mut instance, leaves);
94
95        // SAFETY: Initialized by constructor above
96        unsafe { instance.assume_init() }
97    }
98
99    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
100    fn init_internal(leaves: &[[u8; OUT_LEN]; N], tree: &mut [MaybeUninit<[u8; OUT_LEN]>; N - 1]) {
101        let mut tree_hashes = tree.as_mut_slice();
102        let mut level_hashes = leaves.as_slice();
103
104        let mut pair = [0u8; OUT_LEN * 2];
105        while level_hashes.len() > 1 {
106            let num_pairs = level_hashes.len() / 2;
107            let parent_hashes;
108            // SAFETY: The size of the tree is statically known to match the number of leaves and
109            // levels of hashes
110            (parent_hashes, tree_hashes) = unsafe { tree_hashes.split_at_mut_unchecked(num_pairs) };
111
112            for ([left_hash, right_hash], parent_hash) in
113                level_hashes.array_chunks().zip(parent_hashes.iter_mut())
114            {
115                pair[..OUT_LEN].copy_from_slice(left_hash);
116                pair[OUT_LEN..].copy_from_slice(right_hash);
117
118                parent_hash.write(hash_pair(left_hash, right_hash));
119            }
120
121            // SAFETY: Just initialized
122            level_hashes = unsafe { parent_hashes.assume_init_ref() };
123        }
124    }
125
126    /// Compute Merkle Tree root.
127    ///
128    /// This is functionally equivalent to creating an instance first and calling [`Self::root()`]
129    /// method, but is faster and avoids heap allocation when root is the only thing that is needed.
130    #[inline]
131    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
132    pub fn compute_root_only(leaves: &[[u8; OUT_LEN]; N]) -> [u8; OUT_LEN]
133    where
134        [(); N.ilog2() as usize + 1]:,
135    {
136        if leaves.len() == 1 {
137            return leaves[0];
138        }
139
140        // Stack of intermediate nodes per tree level
141        let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
142
143        for (num_leaves, &hash) in leaves.iter().enumerate() {
144            let mut current = hash;
145
146            // Every bit set to `1` corresponds to an active Merkle Tree level
147            let lowest_active_levels = num_leaves.trailing_ones() as usize;
148            for item in stack.iter().take(lowest_active_levels) {
149                current = hash_pair(item, &current);
150            }
151
152            // Place the current hash at the first inactive level
153            stack[lowest_active_levels] = current;
154        }
155
156        stack[N.ilog2() as usize]
157    }
158
159    /// Get the root of Merkle Tree.
160    ///
161    /// In case a tree contains a single leaf hash, that leaf hash is returned.
162    #[inline]
163    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
164    pub fn root(&self) -> [u8; OUT_LEN] {
165        *self
166            .tree
167            .last()
168            .or(self.leaves.last())
169            .expect("There is always at least one leaf hash; qed")
170    }
171
172    /// Iterator over proofs in the same order as provided leaf hashes
173    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
174    pub fn all_proofs(
175        &self,
176    ) -> impl ExactSizeIterator<Item = [[u8; OUT_LEN]; N.ilog2() as usize]> + TrustedLen
177    where
178        [(); N.ilog2() as usize]:,
179    {
180        let iter = self
181            .leaves
182            .array_chunks()
183            .enumerate()
184            .flat_map(|(pair_index, &[left_hash, right_hash])| {
185                let mut left_proof = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); N.ilog2() as usize];
186                left_proof[0].write(right_hash);
187
188                let left_proof = {
189                    let (_, shared_proof) = left_proof.split_at_mut(1);
190
191                    let mut tree_hashes = self.tree.as_slice();
192                    let mut parent_position = pair_index;
193                    let mut parent_level_size = N / 2;
194
195                    for hash in shared_proof {
196                        // Line below is a more efficient branchless version of this:
197                        // let parent_other_position = if parent_position % 2 == 0 {
198                        //     parent_position + 1
199                        // } else {
200                        //     parent_position - 1
201                        // };
202                        let parent_other_position = parent_position ^ 1;
203
204                        // SAFETY: Statically guaranteed to be present by constructor
205                        let other_hash =
206                            unsafe { tree_hashes.get_unchecked(parent_other_position) };
207                        hash.write(*other_hash);
208                        (_, tree_hashes) = tree_hashes.split_at(parent_level_size);
209
210                        parent_position /= 2;
211                        parent_level_size /= 2;
212                    }
213
214                    // SAFETY: Just initialized
215                    unsafe { left_proof.transpose().assume_init() }
216                };
217
218                let mut right_proof = left_proof;
219                right_proof[0] = left_hash;
220
221                [left_proof, right_proof]
222            })
223            // Special case for a single leaf tree to make sure proof is returned, even if it is
224            // empty
225            .chain({
226                let mut returned = false;
227
228                iter::from_fn(move || {
229                    if N == 1 && !returned {
230                        returned = true;
231                        Some([[0; OUT_LEN]; N.ilog2() as usize])
232                    } else {
233                        None
234                    }
235                })
236            });
237
238        ProofsIterator { iter, len: N }
239    }
240
241    /// Verify previously generated proof
242    #[inline]
243    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
244    pub fn verify(
245        root: &[u8; OUT_LEN],
246        proof: &[[u8; OUT_LEN]; N.ilog2() as usize],
247        leaf_index: usize,
248        leaf: [u8; OUT_LEN],
249    ) -> bool
250    where
251        [(); N.ilog2() as usize]:,
252    {
253        if leaf_index >= N {
254            return false;
255        }
256
257        let mut computed_root = leaf;
258
259        let mut position = leaf_index;
260        for hash in proof {
261            computed_root = if position % 2 == 0 {
262                hash_pair(&computed_root, hash)
263            } else {
264                hash_pair(hash, &computed_root)
265            };
266
267            position /= 2;
268        }
269
270        root == &computed_root
271    }
272}
273
/// Iterator adapter that reports an exact, externally provided number of remaining items.
///
/// `len` is decremented on every call to `next()` (never going below zero), regardless of whether
/// the wrapped iterator produced an item.
struct ProofsIterator<I> {
    /// Wrapped iterator producing the actual items
    iter: I,
    /// Remaining number of items reported by `size_hint()`, `len()` and `count()`
    len: usize,
}

impl<I: Iterator> Iterator for ProofsIterator<I> {
    type Item = I::Item;

    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    fn next(&mut self) -> Option<Self::Item> {
        let next_item = self.iter.next();
        // Count down unconditionally, stopping at zero
        if self.len != 0 {
            self.len -= 1;
        }
        next_item
    }

    #[inline(always)]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.len;
        (remaining, Some(remaining))
    }

    #[inline(always)]
    fn count(self) -> usize
    where
        Self: Sized,
    {
        // The number of items is known upfront, no need to drain the wrapped iterator
        self.len
    }
}

impl<I: Iterator> ExactSizeIterator for ProofsIterator<I> {
    #[inline(always)]
    fn len(&self) -> usize {
        let (lower, _) = self.size_hint();
        lower
    }
}
316
317unsafe impl<Iter> TrustedLen for ProofsIterator<Iter> where Iter: Iterator {}