ab_merkle_tree/balanced.rs

use crate::hash_pair;
use ab_blake3::OUT_LEN;
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
use core::iter::TrustedLen;
use core::mem::MaybeUninit;
use core::num::NonZero;

/// Ensures that only a supported `N` can be specified for [`BalancedMerkleTree`].
///
/// This is essentially a workaround for the current Rust type system constraints that do not allow
/// a nicer way to do the same thing at compile time.
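///
/// # Example
///
/// A minimal sketch of the intended pattern (mirroring how [`BalancedMerkleTree`] uses it below);
/// the `Example` type and the crate path are illustrative assumptions, and the snippet is not
/// compiled as a doctest:
///
/// ```ignore
/// use ab_merkle_tree::balanced::ensure_supported_n;
///
/// struct Example<const N: usize>
/// where
///     // Compilation fails unless `N` is a power of two greater than one
///     [(); ensure_supported_n(N)]:,
/// {
///     values: [u64; N],
/// }
/// ```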
pub const fn ensure_supported_n(n: usize) -> usize {
    assert!(
        n.is_power_of_two(),
        "Balanced Merkle Tree must have a number of leaves that is a power of 2"
    );

    assert!(
        n > 1,
        "This Balanced Merkle Tree must have more than one leaf"
    );

    0
}

/// Merkle Tree variant that has hash-sized leaves and is fully balanced according to the
/// configured generic parameter.
///
/// This can be considered a general case of [`UnbalancedMerkleTree`]. The root and proofs are
/// identical for both when the number of leaves is a power of two. In that case,
/// [`UnbalancedMerkleTree`] is still useful when only a single proof needs to be generated and the
/// number of leaves is very large (it can generate proofs with very little RAM usage compared to
/// this version).
///
/// [`UnbalancedMerkleTree`]: crate::unbalanced::UnbalancedMerkleTree
///
/// This Merkle Tree implementation is best suited for use cases where proofs for all (or most) of
/// the elements need to be generated and the whole tree easily fits into memory. It can also be
/// constructed and proofs can be generated efficiently without heap allocations.
///
/// With all parameters of the tree known statically, it results in the most efficient version of
/// the code being generated for a given set of parameters.
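///
/// # Example
///
/// A minimal usage sketch; the crate/module path is assumed, a nightly toolchain with
/// `generic_const_exprs` is required, and `OUT_LEN` is BLAKE3's 32-byte hash length. Not compiled
/// as a doctest:
///
/// ```ignore
/// use ab_merkle_tree::balanced::BalancedMerkleTree;
///
/// // Four 32-byte leaf hashes
/// let leaves = [[1u8; 32], [2u8; 32], [3u8; 32], [4u8; 32]];
///
/// // Build the tree and compute its root
/// let tree = BalancedMerkleTree::new(&leaves);
/// let root = tree.root();
///
/// // Each proof contains `N.ilog2()` hashes (2 for 4 leaves); there is one proof per leaf
/// assert_eq!(tree.all_proofs().len(), 4);
/// ```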
#[derive(Debug)]
pub struct BalancedMerkleTree<'a, const N: usize>
where
    [(); N - 1]:,
{
    leaves: &'a [[u8; OUT_LEN]],
    // This tree doesn't include leaves because we have them in the `leaves` field
    tree: [[u8; OUT_LEN]; N - 1],
}

// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
//  https://github.com/BLAKE3-team/BLAKE3/issues/478
impl<'a, const N: usize> BalancedMerkleTree<'a, N>
where
    [(); N - 1]:,
    [(); ensure_supported_n(N)]:,
{
    /// Create a new tree from a fixed set of elements.
    ///
    /// The data structure is statically allocated and might be too large to fit on the stack!
    /// If that is the case, use the `new_boxed()` method.
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn new(leaves: &'a [[u8; OUT_LEN]; N]) -> Self {
        let mut tree = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); _];

        Self::init_internal(leaves, &mut tree);

        Self {
            leaves,
            // SAFETY: Statically guaranteed for all elements to be initialized
            tree: unsafe { tree.transpose().assume_init() },
        }
    }

    /// Like [`Self::new()`], but uses pre-allocated memory for instantiation
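    ///
    /// # Example
    ///
    /// A minimal sketch (crate/module path assumed) that initializes the tree in caller-provided
    /// uninitialized memory; not compiled as a doctest:
    ///
    /// ```ignore
    /// use ab_merkle_tree::balanced::BalancedMerkleTree;
    /// use core::mem::MaybeUninit;
    ///
    /// let leaves = [[0u8; 32]; 8];
    /// // Reserve uninitialized space for the instance (it could also live inside a larger
    /// // allocation) and initialize it in place
    /// let mut instance = MaybeUninit::<BalancedMerkleTree<8>>::uninit();
    /// let tree = BalancedMerkleTree::new_in(&mut instance, &leaves);
    /// let root = tree.root();
    /// ```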
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn new_in<'b>(
        instance: &'b mut MaybeUninit<Self>,
        leaves: &'a [[u8; OUT_LEN]; N],
    ) -> &'b mut Self {
        let instance_ptr = instance.as_mut_ptr();
        // SAFETY: Valid and correctly aligned non-null pointer
        unsafe {
            (&raw mut (*instance_ptr).leaves).write(leaves);
        }
        let tree = {
            // SAFETY: Valid and correctly aligned non-null pointer
            let tree_ptr = unsafe { &raw mut (*instance_ptr).tree };
            // SAFETY: Allocated and correctly aligned uninitialized data
            unsafe {
                tree_ptr
                    .cast::<[MaybeUninit<[u8; OUT_LEN]>; N - 1]>()
                    .as_mut_unchecked()
            }
        };

        Self::init_internal(leaves, tree);

        // SAFETY: Initialized field by field above
        unsafe { instance.assume_init_mut() }
    }

    /// Like [`Self::new()`], but creates a heap-allocated instance, avoiding excessive stack usage
    /// for large trees
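    ///
    /// # Example
    ///
    /// A minimal sketch (crate/module path assumed, `alloc` feature enabled); not compiled as a
    /// doctest:
    ///
    /// ```ignore
    /// use ab_merkle_tree::balanced::BalancedMerkleTree;
    ///
    /// // With 1024 leaves the internal nodes alone take ~32 KiB, so keep the tree on the heap
    /// let leaves = [[0u8; 32]; 1024];
    /// let tree = BalancedMerkleTree::new_boxed(&leaves);
    /// let root = tree.root();
    /// ```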
    #[cfg(feature = "alloc")]
    pub fn new_boxed(leaves: &'a [[u8; OUT_LEN]; N]) -> Box<Self> {
        let mut instance = Box::<Self>::new_uninit();

        Self::new_in(&mut instance, leaves);

        // SAFETY: Initialized by constructor above
        unsafe { instance.assume_init() }
    }

    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    fn init_internal(leaves: &[[u8; OUT_LEN]; N], tree: &mut [MaybeUninit<[u8; OUT_LEN]>; N - 1]) {
        let mut tree_hashes = tree.as_mut_slice();
        let mut level_hashes = leaves.as_slice();

        while level_hashes.len() > 1 {
            let num_pairs = level_hashes.len() / 2;
            let parent_hashes;
            // SAFETY: The size of the tree is statically known to match the number of leaves and
            // levels of hashes
            (parent_hashes, tree_hashes) = unsafe { tree_hashes.split_at_mut_unchecked(num_pairs) };

            for ([left_hash, right_hash], parent_hash) in
                level_hashes.array_chunks().zip(parent_hashes.iter_mut())
            {
                parent_hash.write(hash_pair(left_hash, right_hash));
            }

            // SAFETY: Just initialized
            level_hashes = unsafe { parent_hashes.assume_init_ref() };
        }
    }

    // TODO: Method that generates not only root, but also proof, like the Unbalanced Merkle Tree
    /// Compute Merkle Tree root.
    ///
    /// This is functionally equivalent to creating an instance first and calling the
    /// [`Self::root()`] method, but is faster and avoids allocating the whole tree when only the
    /// root is needed.
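    ///
    /// # Example
    ///
    /// A minimal sketch (crate/module path assumed); not compiled as a doctest:
    ///
    /// ```ignore
    /// use ab_merkle_tree::balanced::BalancedMerkleTree;
    ///
    /// let leaves = [[0u8; 32]; 16];
    /// let root = BalancedMerkleTree::<16>::compute_root_only(&leaves);
    /// assert_eq!(root, BalancedMerkleTree::new(&leaves).root());
    /// ```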
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn compute_root_only(leaves: &[[u8; OUT_LEN]; N]) -> [u8; OUT_LEN]
    where
        [(); N.ilog2() as usize + 1]:,
    {
        // Stack of intermediate nodes per tree level
        let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];

        for (num_leaves, &hash) in leaves.iter().enumerate() {
            let mut current = hash;

            // Every bit set to `1` corresponds to an active Merkle Tree level
            let lowest_active_levels = num_leaves.trailing_ones() as usize;
            for item in stack.iter().take(lowest_active_levels) {
                current = hash_pair(item, &current);
            }

            // Place the current hash at the first inactive level
            stack[lowest_active_levels] = current;
        }

        stack[N.ilog2() as usize]
    }

    /// Get the root of the Merkle Tree
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn root(&self) -> [u8; OUT_LEN] {
        *self
            .tree
            .last()
            .or(self.leaves.last())
            .expect("There is always at least one leaf hash; qed")
    }

    /// Iterator over proofs, in the same order as the provided leaf hashes
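    ///
    /// # Example
    ///
    /// A minimal sketch (crate/module path assumed); not compiled as a doctest:
    ///
    /// ```ignore
    /// use ab_merkle_tree::balanced::BalancedMerkleTree;
    ///
    /// let leaves = [[1u8; 32], [2u8; 32], [3u8; 32], [4u8; 32]];
    /// let tree = BalancedMerkleTree::new(&leaves);
    /// let root = tree.root();
    ///
    /// // One proof per leaf, in leaf order, each verifiable against the root
    /// for (leaf_index, (proof, leaf)) in tree.all_proofs().zip(leaves).enumerate() {
    ///     assert!(BalancedMerkleTree::<4>::verify(&root, &proof, leaf_index, leaf));
    /// }
    /// ```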
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn all_proofs(
        &self,
    ) -> impl ExactSizeIterator<Item = [[u8; OUT_LEN]; N.ilog2() as usize]> + TrustedLen
    where
        [(); N.ilog2() as usize]:,
    {
        let iter = self.leaves.array_chunks().enumerate().flat_map(
            |(pair_index, &[left_hash, right_hash])| {
                let mut left_proof = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); N.ilog2() as usize];
                left_proof[0].write(right_hash);

                let left_proof = {
                    let (_, shared_proof) = left_proof.split_at_mut(1);

                    let mut tree_hashes = self.tree.as_slice();
                    let mut parent_position = pair_index;
                    let mut parent_level_size = N / 2;

                    for hash in shared_proof {
                        // The line below is a more efficient branchless version of this:
                        // let parent_other_position = if parent_position % 2 == 0 {
                        //     parent_position + 1
                        // } else {
                        //     parent_position - 1
                        // };
                        let parent_other_position = parent_position ^ 1;

                        // SAFETY: Statically guaranteed to be present by constructor
                        let other_hash =
                            unsafe { tree_hashes.get_unchecked(parent_other_position) };
                        hash.write(*other_hash);
                        // Skip the current level so that the slice starts at the next level up
                        (_, tree_hashes) = tree_hashes.split_at(parent_level_size);

                        parent_position /= 2;
                        parent_level_size /= 2;
                    }

                    // SAFETY: Just initialized
                    unsafe { left_proof.transpose().assume_init() }
                };

                let mut right_proof = left_proof;
                right_proof[0] = left_hash;

                [left_proof, right_proof]
            },
        );

        ProofsIterator { iter, len: N }
    }

    /// Verify a previously generated proof
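    ///
    /// # Example
    ///
    /// A minimal sketch (crate/module path assumed); not compiled as a doctest:
    ///
    /// ```ignore
    /// use ab_merkle_tree::balanced::BalancedMerkleTree;
    ///
    /// let leaves = [[1u8; 32], [2u8; 32], [3u8; 32], [4u8; 32]];
    /// let tree = BalancedMerkleTree::new(&leaves);
    /// let root = tree.root();
    /// let proof = tree.all_proofs().next().unwrap();
    ///
    /// // The proof is only valid for the leaf and index it was generated for
    /// assert!(BalancedMerkleTree::<4>::verify(&root, &proof, 0, leaves[0]));
    /// assert!(!BalancedMerkleTree::<4>::verify(&root, &proof, 1, leaves[0]));
    /// ```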
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn verify(
        root: &[u8; OUT_LEN],
        proof: &[[u8; OUT_LEN]; N.ilog2() as usize],
        leaf_index: usize,
        leaf: [u8; OUT_LEN],
    ) -> bool
    where
        [(); N.ilog2() as usize]:,
    {
        if leaf_index >= N {
            return false;
        }

        let mut computed_root = leaf;

        let mut position = leaf_index;
        for hash in proof {
            computed_root = if position.is_multiple_of(2) {
                hash_pair(&computed_root, hash)
            } else {
                hash_pair(hash, &computed_root)
            };

            position /= 2;
        }

        root == &computed_root
    }
}

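/// Iterator wrapper that tracks the number of remaining items itself, since the inner `flat_map`
/// iterator cannot report an exact size; this is what allows exposing [`ExactSizeIterator`] and
/// [`TrustedLen`] from [`BalancedMerkleTree::all_proofs()`].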
struct ProofsIterator<Iter> {
    iter: Iter,
    len: usize,
}

impl<Iter> Iterator for ProofsIterator<Iter>
where
    Iter: Iterator,
{
    type Item = Iter::Item;

    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.iter.next();
        self.len = self.len.saturating_sub(1);
        item
    }

    #[inline(always)]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.len, Some(self.len))
    }

    #[inline(always)]
    fn count(self) -> usize
    where
        Self: Sized,
    {
        self.len
    }

    #[inline(always)]
    fn last(self) -> Option<Self::Item>
    where
        Self: Sized,
    {
        self.iter.last()
    }

    #[inline(always)]
    fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
        self.len = self.len.saturating_sub(n);
        self.iter.advance_by(n)
    }

    #[inline(always)]
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        self.len = self.len.saturating_sub(n.saturating_add(1));
        self.iter.nth(n)
    }
}

impl<Iter> ExactSizeIterator for ProofsIterator<Iter>
where
    Iter: Iterator,
{
    #[inline(always)]
    fn len(&self) -> usize {
        self.len
    }
}

unsafe impl<Iter> TrustedLen for ProofsIterator<Iter> where Iter: Iterator {}