// ab_merkle_tree/balanced_hashed.rs

1#[cfg(feature = "alloc")]
2extern crate alloc;
3
4#[cfg(feature = "alloc")]
5use alloc::boxed::Box;
6use blake3::OUT_LEN;
7use core::iter::TrustedLen;
8use core::mem;
9use core::mem::MaybeUninit;
10
/// Count of hashes stored in the tree itself: every internal node including the root, but none of
/// the leaf hashes, which are supplied by the caller.
///
/// A perfect binary tree with `2^num_leaves_log_2` leaves has exactly one internal node fewer than
/// it has leaves.
#[inline(always)]
pub const fn num_hashes(num_leaves_log_2: u32) -> usize {
    let leaves = usize::pow(2, num_leaves_log_2);
    leaves - 1
}
16
/// Count of leaves in a balanced tree, i.e. `2^num_leaves_log_2`.
#[inline(always)]
pub const fn num_leaves(num_leaves_log_2: u32) -> usize {
    let base: usize = 2;
    base.pow(num_leaves_log_2)
}
22
23/// Merkle Tree variant that has hash-sized leaves and is fully balanced according to configured
24/// generic parameter.
25///
26/// This Merkle Tree implementation is best suited for use cases when proofs for all (or most) of
27/// the elements need to be generated and the whole tree easily fits into memory. It can also be
28/// constructed and proofs can be generated efficiently without heap allocations.
29///
30/// With all parameters of the tree known statically, it results in the most efficient version of
31/// the code being generated for a given set of parameters.
32///
33/// `NUM_LEAVES_LOG_2` is base-2 logarithm of the number of leaves in a tree.
34#[derive(Debug)]
35pub struct BalancedHashedMerkleTree<'a, const NUM_LEAVES_LOG_2: u32>
36where
37    [(); num_hashes(NUM_LEAVES_LOG_2)]:,
38{
39    leaf_hashes: &'a [[u8; OUT_LEN]],
40    // This tree doesn't include leaves because we know the size
41    tree: [[u8; OUT_LEN]; num_hashes(NUM_LEAVES_LOG_2)],
42}
43
44// TODO: Replace hashing individual records with blake3 and building tree manually with building the
45//  tree using blake3 itself, such that the root is the same as hashing data with blake3, see
46//  https://github.com/BLAKE3-team/BLAKE3/issues/470 for details. Two options are:
47//  expand values to 1024 bytes or modify blake3 to use 32-byte chunk size (at which point it'll
48//  unfortunately stop being blake3)
49impl<'a, const NUM_LEAVES_LOG_2: u32> BalancedHashedMerkleTree<'a, NUM_LEAVES_LOG_2>
50where
51    [(); num_hashes(NUM_LEAVES_LOG_2)]:,
52{
53    /// Create a new tree from a fixed set of elements.
54    ///
55    /// The data structure is statically allocated and might be too large to fit on the stack!
56    /// If that is the case, use `new_boxed()` method.
57    pub fn new(leaf_hashes: &'a [[u8; OUT_LEN]; num_leaves(NUM_LEAVES_LOG_2)]) -> Self {
58        let mut tree = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); num_hashes(NUM_LEAVES_LOG_2)];
59
60        Self::init_internal(leaf_hashes, &mut tree);
61
62        Self {
63            leaf_hashes,
64            // SAFETY: Statically guaranteed for all elements to be initialized
65            tree: unsafe { tree.transpose().assume_init() },
66        }
67    }
68
69    /// Like [`Self::new()`], but used pre-allocated memory for instantiation
70    pub fn new_in<'b>(
71        instance: &'b mut MaybeUninit<Self>,
72        leaf_hashes: &'a [[u8; OUT_LEN]; num_leaves(NUM_LEAVES_LOG_2)],
73    ) -> &'b mut Self {
74        let instance_ptr = instance.as_mut_ptr();
75        // SAFETY: Valid and correctly aligned non-null pointer
76        unsafe {
77            (&raw mut (*instance_ptr).leaf_hashes).write(leaf_hashes);
78        }
79        let tree = {
80            // SAFETY: Valid and correctly aligned non-null pointer
81            let tree_ptr = unsafe { &raw mut (*instance_ptr).tree };
82            // SAFETY: Allocated and correctly aligned uninitialized data
83            unsafe {
84                tree_ptr
85                    .cast::<[MaybeUninit<[u8; OUT_LEN]>; num_hashes(NUM_LEAVES_LOG_2)]>()
86                    .as_mut_unchecked()
87            }
88        };
89
90        Self::init_internal(leaf_hashes, tree);
91
92        // SAFETY: Initialized field by field above
93        unsafe { instance.assume_init_mut() }
94    }
95
96    /// Like [`Self::new()`], but creates heap-allocated instance, avoiding excessive stack usage
97    /// for large trees
98    #[cfg(feature = "alloc")]
99    pub fn new_boxed(leaf_hashes: &'a [[u8; OUT_LEN]; num_leaves(NUM_LEAVES_LOG_2)]) -> Box<Self> {
100        let mut instance = Box::<Self>::new_uninit();
101
102        Self::new_in(&mut instance, leaf_hashes);
103
104        // SAFETY: Initialized by constructor above
105        unsafe { instance.assume_init() }
106    }
107
108    fn init_internal(
109        leaf_hashes: &[[u8; OUT_LEN]; num_leaves(NUM_LEAVES_LOG_2)],
110        tree: &mut [MaybeUninit<[u8; OUT_LEN]>; num_hashes(NUM_LEAVES_LOG_2)],
111    ) {
112        let mut tree_hashes = tree.as_mut_slice();
113        let mut level_hashes = leaf_hashes.as_slice();
114
115        let mut pair = [0u8; OUT_LEN * 2];
116        while level_hashes.len() > 1 {
117            let num_pairs = level_hashes.len() / 2;
118            let parent_hashes;
119            // SAFETY: The size of the tree is statically known to match the number of leaves and
120            // levels of hashes
121            (parent_hashes, tree_hashes) = unsafe { tree_hashes.split_at_mut_unchecked(num_pairs) };
122
123            for pair_index in 0..num_pairs {
124                // SAFETY: Entry is statically known to be present
125                let left_hash = unsafe { level_hashes.get_unchecked(pair_index * 2) };
126                // SAFETY: Entry is statically known to be present
127                let right_hash = unsafe { level_hashes.get_unchecked(pair_index * 2 + 1) };
128                // SAFETY: Entry is statically known to be present
129                let parent_hash = unsafe { parent_hashes.get_unchecked_mut(pair_index) };
130
131                pair[..OUT_LEN].copy_from_slice(left_hash);
132                pair[OUT_LEN..].copy_from_slice(right_hash);
133
134                parent_hash.write(*blake3::hash(&pair).as_bytes());
135            }
136
137            // SAFETY: Just initialized
138            level_hashes = unsafe { parent_hashes.assume_init_ref() };
139        }
140    }
141
142    /// Get the root of Merkle Tree.
143    ///
144    /// In case a tree contains a single leaf hash, that leaf hash is returned.
145    #[inline]
146    pub fn root(&self) -> [u8; OUT_LEN] {
147        *self
148            .tree
149            .last()
150            .or(self.leaf_hashes.last())
151            .expect("There is always at least one leaf hash; qed")
152    }
153
154    /// Iterator over proofs in the same order as provided leaf hashes
155    pub fn all_proofs(
156        &self,
157    ) -> impl ExactSizeIterator<Item = [u8; OUT_LEN * NUM_LEAVES_LOG_2 as usize]> + TrustedLen
158    where
159        [(); OUT_LEN * NUM_LEAVES_LOG_2 as usize]:,
160    {
161        let iter = self.leaf_hashes.array_chunks().enumerate().flat_map(
162            |(pair_index, &[left_hash, right_hash])| {
163                let mut left_proof =
164                    [MaybeUninit::<[u8; OUT_LEN]>::uninit(); NUM_LEAVES_LOG_2 as usize];
165                left_proof[0].write(right_hash);
166
167                let left_proof = {
168                    let (_, shared_proof) = left_proof.split_at_mut(1);
169
170                    let mut tree_hashes = self.tree.as_slice();
171                    let mut parent_position = pair_index;
172                    let mut parent_level_size = num_leaves(NUM_LEAVES_LOG_2) / 2;
173
174                    for hash in shared_proof {
175                        let parent_other_position = if parent_position % 2 == 0 {
176                            parent_position + 1
177                        } else {
178                            parent_position - 1
179                        };
180                        // SAFETY: Statically guaranteed to be present by constructor
181                        let other_hash =
182                            unsafe { tree_hashes.get_unchecked(parent_other_position) };
183                        hash.write(*other_hash);
184                        (_, tree_hashes) = tree_hashes.split_at(parent_level_size);
185
186                        parent_position /= 2;
187                        parent_level_size /= 2;
188                    }
189
190                    // SAFETY: Just initialized
191                    unsafe { left_proof.transpose().assume_init() }
192                };
193
194                let mut right_proof = left_proof;
195                right_proof[0] = left_hash;
196
197                // TODO: Should have been just `transmute`, but compiler has a bug:
198                //  https://github.com/rust-lang/rust/issues/61956
199                // SAFETY: From and to have the same size and alignment
200                let left_proof = unsafe {
201                    mem::transmute_copy::<
202                        [[u8; OUT_LEN]; NUM_LEAVES_LOG_2 as usize],
203                        [u8; OUT_LEN * NUM_LEAVES_LOG_2 as usize],
204                    >(&left_proof)
205                };
206                let right_proof = unsafe {
207                    mem::transmute_copy::<
208                        [[u8; OUT_LEN]; NUM_LEAVES_LOG_2 as usize],
209                        [u8; OUT_LEN * NUM_LEAVES_LOG_2 as usize],
210                    >(&right_proof)
211                };
212                [left_proof, right_proof]
213            },
214        );
215
216        ProofsIterator {
217            iter,
218            len: num_leaves(NUM_LEAVES_LOG_2),
219        }
220    }
221
222    /// Verify previously generated proof
223    #[inline]
224    pub fn verify(
225        root: &[u8; OUT_LEN],
226        proof: &[u8; OUT_LEN * NUM_LEAVES_LOG_2 as usize],
227        leaf_index: usize,
228        leaf_hash: [u8; OUT_LEN],
229    ) -> bool
230    where
231        [(); OUT_LEN * NUM_LEAVES_LOG_2 as usize]:,
232    {
233        if leaf_index >= num_leaves(NUM_LEAVES_LOG_2) {
234            return false;
235        }
236
237        let mut computed_root = leaf_hash;
238
239        let mut position = leaf_index;
240        let mut pair = [0u8; OUT_LEN * 2];
241        for hash in proof.array_chunks::<OUT_LEN>() {
242            if position % 2 == 0 {
243                pair[..OUT_LEN].copy_from_slice(&computed_root);
244                pair[OUT_LEN..].copy_from_slice(hash);
245            } else {
246                pair[..OUT_LEN].copy_from_slice(hash);
247                pair[OUT_LEN..].copy_from_slice(&computed_root);
248            }
249
250            position /= 2;
251            computed_root = *blake3::hash(&pair).as_bytes();
252        }
253
254        root == &computed_root
255    }
256}
257
258struct ProofsIterator<Iter> {
259    iter: Iter,
260    len: usize,
261}
262
263impl<Iter> Iterator for ProofsIterator<Iter>
264where
265    Iter: Iterator,
266{
267    type Item = Iter::Item;
268
269    #[inline(always)]
270    fn next(&mut self) -> Option<Self::Item> {
271        self.iter.next()
272    }
273
274    #[inline(always)]
275    fn size_hint(&self) -> (usize, Option<usize>) {
276        (self.len, Some(self.len))
277    }
278
279    #[inline(always)]
280    fn count(self) -> usize
281    where
282        Self: Sized,
283    {
284        self.len
285    }
286}
287
288impl<Iter> ExactSizeIterator for ProofsIterator<Iter>
289where
290    Iter: Iterator,
291{
292    #[inline(always)]
293    fn len(&self) -> usize {
294        self.len
295    }
296}
297
298unsafe impl<Iter> TrustedLen for ProofsIterator<Iter> where Iter: Iterator {}