// ab_merkle_tree/balanced.rs
use crate::{hash_pair, hash_pair_block, hash_pairs};
2use ab_blake3::{BLOCK_LEN, OUT_LEN};
3#[cfg(feature = "alloc")]
4use alloc::boxed::Box;
5use core::iter::TrustedLen;
6use core::mem;
7use core::mem::MaybeUninit;
8use core::num::NonZero;
9
// Number of hash blocks processed per batched hashing call
const BATCH_HASH_NUM_BLOCKS: usize = 16;
// Number of leaf hashes consumed by one batched hashing call; each block holds
// `BLOCK_LEN / OUT_LEN` hashes
const BATCH_HASH_NUM_LEAVES: usize = BATCH_HASH_NUM_BLOCKS * BLOCK_LEN / OUT_LEN;
14
15pub const fn compute_root_only_large_stack_size(n: usize) -> usize {
18 if n < BATCH_HASH_NUM_LEAVES {
21 return 1;
22 }
23
24 (n / BATCH_HASH_NUM_LEAVES).ilog2() as usize + 1
25}
26
/// Compile-time guard for the supported number of leaves: `n` must be a power of two and
/// greater than one. Returns `0` so it can be used as an array length in a `const` generic
/// bound (`[(); ensure_supported_n(N)]:`), turning an invalid `N` into a compile error.
pub const fn ensure_supported_n(n: usize) -> usize {
    if !n.is_power_of_two() {
        panic!("Balanced Merkle Tree must have a number of leaves that is a power of 2");
    }

    if n < 2 {
        panic!("This Balanced Merkle Tree must have more than one leaf");
    }

    0
}
44
/// Balanced Merkle Tree with `N` leaf hashes.
#[derive(Debug)]
pub struct BalancedMerkleTree<'a, const N: usize>
where
    [(); N - 1]:,
{
    // Borrowed leaf hashes; the tree does not own or copy them
    leaves: &'a [[u8; OUT_LEN]],
    // All internal node hashes, stored level by level from the level directly above the
    // leaves up to the root (the last element)
    tree: [[u8; OUT_LEN]; N - 1],
}
71
72impl<'a, const N: usize> BalancedMerkleTree<'a, N>
75where
76 [(); N - 1]:,
77 [(); ensure_supported_n(N)]:,
78{
79 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
84 pub fn new(leaves: &'a [[u8; OUT_LEN]; N]) -> Self {
85 let mut tree = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); _];
86
87 Self::init_internal(leaves, &mut tree);
88
89 Self {
90 leaves,
91 tree: unsafe { tree.transpose().assume_init() },
93 }
94 }
95
96 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
98 pub fn new_in<'b>(
99 instance: &'b mut MaybeUninit<Self>,
100 leaves: &'a [[u8; OUT_LEN]; N],
101 ) -> &'b mut Self {
102 let instance_ptr = instance.as_mut_ptr();
103 unsafe {
105 (&raw mut (*instance_ptr).leaves).write(leaves);
106 }
107 let tree = {
108 let tree_ptr = unsafe { &raw mut (*instance_ptr).tree };
110 unsafe {
112 tree_ptr
113 .cast::<[MaybeUninit<[u8; OUT_LEN]>; N - 1]>()
114 .as_mut_unchecked()
115 }
116 };
117
118 Self::init_internal(leaves, tree);
119
120 unsafe { instance.assume_init_mut() }
122 }
123
124 #[cfg(feature = "alloc")]
127 pub fn new_boxed(leaves: &'a [[u8; OUT_LEN]; N]) -> Box<Self> {
128 let mut instance = Box::<Self>::new_uninit();
129
130 Self::new_in(&mut instance, leaves);
131
132 unsafe { instance.assume_init() }
134 }
135
136 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
137 fn init_internal(leaves: &[[u8; OUT_LEN]; N], tree: &mut [MaybeUninit<[u8; OUT_LEN]>; N - 1]) {
138 let mut tree_hashes = tree.as_mut_slice();
139 let mut level_hashes = leaves.as_slice();
140
141 while level_hashes.len() > 1 {
142 let num_pairs = level_hashes.len() / 2;
143 let parent_hashes;
144 (parent_hashes, tree_hashes) = unsafe { tree_hashes.split_at_mut_unchecked(num_pairs) };
147
148 if parent_hashes.len().is_multiple_of(BATCH_HASH_NUM_BLOCKS) {
149 let parent_hashes_chunks =
151 unsafe { parent_hashes.as_chunks_unchecked_mut::<BATCH_HASH_NUM_BLOCKS>() };
152 for (pairs, hashes) in level_hashes
153 .as_chunks::<BATCH_HASH_NUM_LEAVES>()
154 .0
155 .iter()
156 .zip(parent_hashes_chunks)
157 {
158 let hashes = unsafe {
162 mem::transmute::<
163 &mut [MaybeUninit<[u8; OUT_LEN]>; BATCH_HASH_NUM_BLOCKS],
164 &mut MaybeUninit<[[u8; OUT_LEN]; BATCH_HASH_NUM_BLOCKS]>,
165 >(hashes)
166 };
167
168 hashes.write(hash_pairs(pairs));
171 }
172 } else {
173 for (pair, parent_hash) in level_hashes
174 .as_chunks()
175 .0
176 .iter()
177 .zip(parent_hashes.iter_mut())
178 {
179 let pair = unsafe {
181 mem::transmute::<&[[u8; OUT_LEN]; BLOCK_LEN / OUT_LEN], &[u8; BLOCK_LEN]>(
182 pair,
183 )
184 };
185 parent_hash.write(hash_pair_block(pair));
186 }
187 }
188
189 level_hashes = unsafe { parent_hashes.assume_init_ref() };
191 }
192 }
193
194 #[inline]
200 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
201 pub fn compute_root_only(leaves: &[[u8; OUT_LEN]; N]) -> [u8; OUT_LEN]
202 where
203 [(); N.ilog2() as usize + 1]:,
204 [(); compute_root_only_large_stack_size(N)]:,
205 {
206 match N {
208 2 => {
209 let [root] = hash_pairs(leaves);
210
211 return root;
212 }
213 4 => {
214 let hashes = hash_pairs::<2, _>(leaves);
215 let [root] = hash_pairs(&hashes);
216
217 return root;
218 }
219 8 => {
220 let hashes = hash_pairs::<4, _>(leaves);
221 let hashes = hash_pairs::<2, _>(&hashes);
222 let [root] = hash_pairs(&hashes);
223
224 return root;
225 }
226 16 => {
227 let hashes = hash_pairs::<8, _>(leaves);
228 let hashes = hash_pairs::<4, _>(&hashes);
229 let hashes = hash_pairs::<2, _>(&hashes);
230 let [root] = hash_pairs(&hashes);
231
232 return root;
233 }
234 _ => {
235 assert!(N >= BATCH_HASH_NUM_LEAVES);
237 }
238 }
239
240 let mut stack =
244 [[[0u8; OUT_LEN]; BATCH_HASH_NUM_BLOCKS]; compute_root_only_large_stack_size(N)];
245
246 let mut parent_current = [[0u8; OUT_LEN]; BATCH_HASH_NUM_LEAVES];
249 for (num_chunks, chunk_leaves) in leaves
250 .as_chunks::<BATCH_HASH_NUM_LEAVES>()
251 .0
252 .iter()
253 .enumerate()
254 {
255 let (_parent_half, current_half) = parent_current.split_at_mut(BATCH_HASH_NUM_BLOCKS);
256
257 let current = hash_pairs::<BATCH_HASH_NUM_BLOCKS, _>(chunk_leaves);
258 current_half.copy_from_slice(¤t);
259
260 let lowest_active_levels = num_chunks.trailing_ones() as usize;
262 for parent in &mut stack[..lowest_active_levels] {
263 let (parent_half, _current_half) =
264 parent_current.split_at_mut(BATCH_HASH_NUM_BLOCKS);
265 parent_half.copy_from_slice(parent);
266
267 let current = hash_pairs::<BATCH_HASH_NUM_BLOCKS, _>(&parent_current);
268
269 let (_parent_half, current_half) =
270 parent_current.split_at_mut(BATCH_HASH_NUM_BLOCKS);
271 current_half.copy_from_slice(¤t);
272 }
273
274 let (_parent_half, current_half) = parent_current.split_at_mut(BATCH_HASH_NUM_BLOCKS);
275
276 stack[lowest_active_levels].copy_from_slice(current_half);
278 }
279
280 let hashes = &mut stack[compute_root_only_large_stack_size(N) - 1];
281 let hashes = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 2 }, _>(hashes);
282 let hashes = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 4 }, _>(&hashes);
283 let hashes = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 8 }, _>(&hashes);
284 let [root] = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 16 }, _>(&hashes);
285
286 root
287 }
288
289 #[inline]
291 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
292 pub fn root(&self) -> [u8; OUT_LEN] {
293 *self
294 .tree
295 .last()
296 .or(self.leaves.last())
297 .expect("There is always at least one leaf hash; qed")
298 }
299
300 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
302 pub fn all_proofs(
303 &self,
304 ) -> impl ExactSizeIterator<Item = [[u8; OUT_LEN]; N.ilog2() as usize]> + TrustedLen
305 where
306 [(); N.ilog2() as usize]:,
307 {
308 let iter = self.leaves.as_chunks().0.iter().enumerate().flat_map(
309 |(pair_index, &[left_hash, right_hash])| {
310 let mut left_proof = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); N.ilog2() as usize];
311 left_proof[0].write(right_hash);
312
313 let left_proof = {
314 let (_, shared_proof) = left_proof.split_at_mut(1);
315
316 let mut tree_hashes = self.tree.as_slice();
317 let mut parent_position = pair_index;
318 let mut parent_level_size = N / 2;
319
320 for hash in shared_proof {
321 let parent_other_position = parent_position ^ 1;
328
329 let other_hash =
331 unsafe { tree_hashes.get_unchecked(parent_other_position) };
332 hash.write(*other_hash);
333 (_, tree_hashes) = tree_hashes.split_at(parent_level_size);
334
335 parent_position /= 2;
336 parent_level_size /= 2;
337 }
338
339 unsafe { left_proof.transpose().assume_init() }
341 };
342
343 let mut right_proof = left_proof;
344 right_proof[0] = left_hash;
345
346 [left_proof, right_proof]
347 },
348 );
349
350 ProofsIterator { iter, len: N }
351 }
352
353 #[inline]
355 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
356 pub fn verify(
357 root: &[u8; OUT_LEN],
358 proof: &[[u8; OUT_LEN]; N.ilog2() as usize],
359 leaf_index: usize,
360 leaf: [u8; OUT_LEN],
361 ) -> bool
362 where
363 [(); N.ilog2() as usize]:,
364 {
365 if leaf_index >= N {
366 return false;
367 }
368
369 let mut computed_root = leaf;
370
371 let mut position = leaf_index;
372 for hash in proof {
373 computed_root = if position.is_multiple_of(2) {
374 hash_pair(&computed_root, hash)
375 } else {
376 hash_pair(hash, &computed_root)
377 };
378
379 position /= 2;
380 }
381
382 root == &computed_root
383 }
384}
385
/// Iterator wrapper that tracks the exact number of remaining items, enabling
/// `ExactSizeIterator` and `TrustedLen` for an inner iterator that cannot report its own
/// length (e.g. `flat_map`).
struct ProofsIterator<Iter> {
    // Inner iterator producing the actual items
    iter: Iter,
    // Exact number of items remaining; kept in sync by every consuming method
    len: usize,
}
390
// NOTE: `len` bookkeeping here backs the unsafe `TrustedLen` impl below, so every method
// that consumes items must decrement `len` by exactly the number consumed.
impl<Iter> Iterator for ProofsIterator<Iter>
where
    Iter: Iterator,
{
    type Item = Iter::Item;

    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.iter.next();
        // One item consumed; saturate so an exhausted iterator stays at zero
        self.len = self.len.saturating_sub(1);
        item
    }

    #[inline(always)]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact bounds — required by the `TrustedLen` contract
        (self.len, Some(self.len))
    }

    #[inline(always)]
    fn count(self) -> usize
    where
        Self: Sized,
    {
        // `len` is exact, so no need to drain the inner iterator
        self.len
    }

    #[inline(always)]
    fn last(self) -> Option<Self::Item>
    where
        Self: Sized,
    {
        self.iter.last()
    }

    #[inline(always)]
    fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
        // Update `len` first; on error the inner iterator is exhausted, matching `len == 0`
        self.len = self.len.saturating_sub(n);
        self.iter.advance_by(n)
    }

    #[inline(always)]
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        // `nth(n)` consumes `n + 1` items (the skipped ones plus the returned one)
        self.len = self.len.saturating_sub(n.saturating_add(1));
        self.iter.nth(n)
    }
}
438
impl<Iter> ExactSizeIterator for ProofsIterator<Iter>
where
    Iter: Iterator,
{
    #[inline(always)]
    fn len(&self) -> usize {
        // `len` is maintained exactly by the `Iterator` methods
        self.len
    }
}
448
// SAFETY: `size_hint()` returns exact bounds derived from `len`, which every consuming
// `Iterator` method keeps in sync with the items taken from `iter` — relies on the
// constructor setting `len` to the true number of items the inner iterator will yield
// (TODO confirm at construction sites)
unsafe impl<Iter> TrustedLen for ProofsIterator<Iter> where Iter: Iterator {}