1use crate::{hash_pair, hash_pair_block, hash_pairs};
2use ab_blake3::{BLOCK_LEN, OUT_LEN};
3#[cfg(feature = "alloc")]
4use alloc::boxed::Box;
5use core::iter::TrustedLen;
6use core::mem;
7use core::mem::MaybeUninit;
8use core::num::NonZero;
9
10const BATCH_HASH_NUM_BLOCKS: usize = 16;
12const BATCH_HASH_NUM_LEAVES: usize = BATCH_HASH_NUM_BLOCKS * BLOCK_LEN / OUT_LEN;
14
15pub const fn compute_root_only_large_stack_size(n: usize) -> usize {
18 if n < BATCH_HASH_NUM_LEAVES {
21 return 1;
22 }
23
24 (n / BATCH_HASH_NUM_LEAVES).ilog2() as usize + 1
25}
26
/// Compile-time guard for supported leaf counts, meant to be used as a
/// `[(); ensure_supported_n(N)]:` bound.
///
/// Panics unless `n` is a power of two greater than one. In a const context, that panic becomes
/// a compile error. Always returns `0` otherwise.
pub const fn ensure_supported_n(n: usize) -> usize {
    if !n.is_power_of_two() {
        panic!("Balanced Merkle Tree must have a number of leaves that is a power of 2");
    }
    if n <= 1 {
        panic!("This Balanced Merkle Tree must have more than one leaf");
    }
    0
}
44
/// Balanced binary Merkle tree over `N` leaf hashes.
///
/// The leaves are borrowed, and all `N - 1` internal nodes are stored inline in `tree`. The
/// constructors live in an `impl` bounded by `ensure_supported_n(N)`, which requires `N` to be a
/// power of two greater than one.
#[derive(Debug)]
pub struct BalancedMerkleTree<'a, const N: usize>
where
    [(); N - 1]:,
{
    /// Leaf hashes the tree was built from.
    leaves: &'a [[u8; OUT_LEN]],
    /// Internal nodes, level by level, starting with the `N / 2` parents of the leaves. The
    /// root is the last element.
    tree: [[u8; OUT_LEN]; N - 1],
}
71
72impl<'a, const N: usize> BalancedMerkleTree<'a, N>
75where
76 [(); N - 1]:,
77 [(); ensure_supported_n(N)]:,
78{
79 #[cfg_attr(
86 all(feature = "no-panic", not(target_arch = "riscv64")),
87 no_panic::no_panic
88 )]
89 pub fn new(leaves: &'a [[u8; OUT_LEN]; N]) -> Self {
90 let mut tree = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); _];
91
92 Self::init_internal(leaves, &mut tree);
93
94 Self {
95 leaves,
96 tree: unsafe { tree.transpose().assume_init() },
98 }
99 }
100
101 #[cfg_attr(
105 all(feature = "no-panic", not(target_arch = "riscv64")),
106 no_panic::no_panic
107 )]
108 pub fn new_in<'b>(
109 instance: &'b mut MaybeUninit<Self>,
110 leaves: &'a [[u8; OUT_LEN]; N],
111 ) -> &'b mut Self {
112 let instance_ptr = instance.as_mut_ptr();
113 unsafe {
115 (&raw mut (*instance_ptr).leaves).write(leaves);
116 }
117 let tree = {
118 let tree_ptr = unsafe { &raw mut (*instance_ptr).tree };
120 unsafe {
122 tree_ptr
123 .cast::<[MaybeUninit<[u8; OUT_LEN]>; N - 1]>()
124 .as_mut_unchecked()
125 }
126 };
127
128 Self::init_internal(leaves, tree);
129
130 unsafe { instance.assume_init_mut() }
132 }
133
134 #[cfg(feature = "alloc")]
137 pub fn new_boxed(leaves: &'a [[u8; OUT_LEN]; N]) -> Box<Self> {
138 let mut instance = Box::<Self>::new_uninit();
139
140 Self::new_in(&mut instance, leaves);
141
142 unsafe { instance.assume_init() }
144 }
145
146 #[cfg_attr(
149 all(feature = "no-panic", not(target_arch = "riscv64")),
150 no_panic::no_panic
151 )]
152 fn init_internal(leaves: &[[u8; OUT_LEN]; N], tree: &mut [MaybeUninit<[u8; OUT_LEN]>; N - 1]) {
153 let mut tree_hashes = tree.as_mut_slice();
154 let mut level_hashes = leaves.as_slice();
155
156 while level_hashes.len() > 1 {
157 let num_pairs = level_hashes.len() / 2;
158 let parent_hashes;
159 (parent_hashes, tree_hashes) = unsafe { tree_hashes.split_at_mut_unchecked(num_pairs) };
162
163 if parent_hashes.len().is_multiple_of(BATCH_HASH_NUM_BLOCKS) {
164 let parent_hashes_chunks =
166 unsafe { parent_hashes.as_chunks_unchecked_mut::<BATCH_HASH_NUM_BLOCKS>() };
167 for (pairs, hashes) in level_hashes
168 .as_chunks::<BATCH_HASH_NUM_LEAVES>()
169 .0
170 .iter()
171 .zip(parent_hashes_chunks)
172 {
173 let hashes = unsafe {
177 mem::transmute::<
178 &mut [MaybeUninit<[u8; OUT_LEN]>; BATCH_HASH_NUM_BLOCKS],
179 &mut MaybeUninit<[[u8; OUT_LEN]; BATCH_HASH_NUM_BLOCKS]>,
180 >(hashes)
181 };
182
183 hashes.write(hash_pairs(pairs));
186 }
187 } else {
188 for (pair, parent_hash) in level_hashes
189 .as_chunks()
190 .0
191 .iter()
192 .zip(parent_hashes.iter_mut())
193 {
194 let pair = unsafe {
196 mem::transmute::<&[[u8; OUT_LEN]; BLOCK_LEN / OUT_LEN], &[u8; BLOCK_LEN]>(
197 pair,
198 )
199 };
200 parent_hash.write(hash_pair_block(pair));
201 }
202 }
203
204 level_hashes = unsafe { parent_hashes.assume_init_ref() };
206 }
207 }
208
209 #[inline]
215 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
216 pub fn compute_root_only(leaves: &[[u8; OUT_LEN]; N]) -> [u8; OUT_LEN]
217 where
218 [(); N.ilog2() as usize + 1]:,
219 [(); compute_root_only_large_stack_size(N)]:,
220 {
221 match N {
223 2 => {
224 let [root] = hash_pairs(leaves);
225
226 return root;
227 }
228 4 => {
229 let hashes = hash_pairs::<2, _>(leaves);
230 let [root] = hash_pairs(&hashes);
231
232 return root;
233 }
234 8 => {
235 let hashes = hash_pairs::<4, _>(leaves);
236 let hashes = hash_pairs::<2, _>(&hashes);
237 let [root] = hash_pairs(&hashes);
238
239 return root;
240 }
241 16 => {
242 let hashes = hash_pairs::<8, _>(leaves);
243 let hashes = hash_pairs::<4, _>(&hashes);
244 let hashes = hash_pairs::<2, _>(&hashes);
245 let [root] = hash_pairs(&hashes);
246
247 return root;
248 }
249 _ => {
250 assert!(N >= BATCH_HASH_NUM_LEAVES);
252 }
253 }
254
255 let mut stack =
259 [[[0u8; OUT_LEN]; BATCH_HASH_NUM_BLOCKS]; compute_root_only_large_stack_size(N)];
260
261 let mut parent_current = [[0u8; OUT_LEN]; BATCH_HASH_NUM_LEAVES];
264 for (num_chunks, chunk_leaves) in leaves
265 .as_chunks::<BATCH_HASH_NUM_LEAVES>()
266 .0
267 .iter()
268 .enumerate()
269 {
270 let current_half = &mut parent_current[BATCH_HASH_NUM_BLOCKS..];
271
272 let current = hash_pairs::<BATCH_HASH_NUM_BLOCKS, _>(chunk_leaves);
273 current_half.copy_from_slice(¤t);
274
275 let lowest_active_levels = num_chunks.trailing_ones() as usize;
277 for parent in &mut stack[..lowest_active_levels] {
278 let parent_half = &mut parent_current[..BATCH_HASH_NUM_BLOCKS];
279 parent_half.copy_from_slice(parent);
280
281 let current = hash_pairs::<BATCH_HASH_NUM_BLOCKS, _>(&parent_current);
282
283 let current_half = &mut parent_current[BATCH_HASH_NUM_BLOCKS..];
284 current_half.copy_from_slice(¤t);
285 }
286
287 let current_half = &mut parent_current[BATCH_HASH_NUM_BLOCKS..];
288
289 stack[lowest_active_levels].copy_from_slice(current_half);
291 }
292
293 let hashes = &mut stack[compute_root_only_large_stack_size(N) - 1];
294 let hashes = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 2 }, _>(hashes);
295 let hashes = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 4 }, _>(&hashes);
296 let hashes = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 8 }, _>(&hashes);
297 let [root] = hash_pairs::<{ BATCH_HASH_NUM_BLOCKS / 16 }, _>(&hashes);
298
299 root
300 }
301
302 #[inline]
304 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
305 pub fn root(&self) -> [u8; OUT_LEN] {
306 *self
307 .tree
308 .last()
309 .or(self.leaves.last())
310 .expect("There is always at least one leaf hash; qed")
311 }
312
313 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
315 pub fn all_proofs(&self) -> ProofsIterator<'_, N>
316 where
317 [(); N.ilog2() as usize]:,
318 {
319 ProofsIterator {
320 leaves: self.leaves,
321 tree: &self.tree,
322 leaf_index: 0,
323 len: N,
324 }
325 }
326
327 #[inline]
329 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
330 pub fn verify(
331 root: &[u8; OUT_LEN],
332 proof: &[[u8; OUT_LEN]; N.ilog2() as usize],
333 leaf_index: usize,
334 leaf: [u8; OUT_LEN],
335 ) -> bool
336 where
337 [(); N.ilog2() as usize]:,
338 {
339 if leaf_index >= N {
340 return false;
341 }
342
343 let mut computed_root = leaf;
344
345 let mut position = leaf_index;
346 for hash in proof {
347 computed_root = if position.is_multiple_of(2) {
348 hash_pair(&computed_root, hash)
349 } else {
350 hash_pair(hash, &computed_root)
351 };
352
353 position /= 2;
354 }
355
356 root == &computed_root
357 }
358}
359
/// Iterator over the inclusion proofs of all leaves of a `BalancedMerkleTree`, in leaf order.
///
/// NOTE(review): `next` indexes `leaves` without bounds checks and assumes `leaf_index + len == N`
/// (as set up by `all_proofs`). Any other code constructing this through the `pub(super)` fields
/// must uphold both invariants.
#[derive(Debug)]
pub struct ProofsIterator<'a, const N: usize>
where
    [(); N.ilog2() as usize]:,
    [(); N - 1]:,
    [(); ensure_supported_n(N)]:,
{
    /// Leaf hashes; assumed to contain (at least) `N` hashes.
    pub(super) leaves: &'a [[u8; OUT_LEN]],
    /// Internal nodes, level by level from just above the leaves up to the root.
    pub(super) tree: &'a [[u8; OUT_LEN]; N - 1],
    /// Index of the leaf whose proof will be yielded next.
    pub(super) leaf_index: usize,
    /// Number of proofs still to be yielded.
    pub(super) len: usize,
}
373
impl<'a, const N: usize> Iterator for ProofsIterator<'a, N>
where
    [(); N.ilog2() as usize]:,
    [(); N - 1]:,
    [(); ensure_supported_n(N)]:,
{
    /// A proof: `N.ilog2()` sibling hashes, ordered from the leaf level towards the root.
    type Item = [[u8; OUT_LEN]; N.ilog2() as usize];

    /// Builds the proof for the next leaf.
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    fn next(&mut self) -> Option<Self::Item> {
        if self.len == 0 {
            return None;
        }
        self.len -= 1;

        let index = self.leaf_index;
        self.leaf_index += 1;

        // Leaves `2k` and `2k + 1` are siblings, so flipping the lowest bit finds the neighbor
        let sibling_index = index ^ 1;
        // SAFETY: With `leaf_index + len == N`, `index < N` here. `N` is even, so `index ^ 1 < N`.
        // This also assumes `self.leaves` holds at least `N` hashes.
        let sibling_hash = *unsafe { self.leaves.get_unchecked(sibling_index) };

        let mut proof = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); _];
        proof[0].write(sibling_hash);

        // The remaining proof elements come from the internal levels stored in `tree`
        let shared_proof = &mut proof[1..];

        // `tree_hashes` always starts at the current level. `parent_position` is the ancestor's
        // index within that level, and `parent_level_size` is the number of nodes in it.
        let mut tree_hashes = self.tree.as_slice();
        let mut parent_position = index / 2;
        let mut parent_level_size = N / 2;

        for hash in shared_proof {
            let parent_other_position = parent_position ^ 1;

            // SAFETY: `parent_position < parent_level_size` and every visited level has an even
            // size (`N / 2` down to `2`), so the sibling index is within the current level
            let other_hash = unsafe { tree_hashes.get_unchecked(parent_other_position) };
            hash.write(*other_hash);
            // Skip past the current level to the start of the next one up
            tree_hashes = &tree_hashes[parent_level_size..];

            parent_position /= 2;
            parent_level_size /= 2;
        }

        // SAFETY: Element 0 and every element of `shared_proof` were written above
        Some(unsafe { proof.transpose().assume_init() })
    }

    /// Exact: both bounds equal the number of remaining proofs.
    #[inline(always)]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.len, Some(self.len))
    }

    /// Returns the remaining count without building any proofs.
    #[inline(always)]
    fn count(self) -> usize {
        self.len
    }

    /// The last proof always belongs to leaf `N - 1`, so this jumps straight to it.
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    fn last(mut self) -> Option<Self::Item> {
        if self.len == 0 {
            return None;
        }
        self.leaf_index = N - 1;
        self.len = 1;
        self.next()
    }

    /// Skips up to `n` leaves without building proofs.
    ///
    /// Returns the shortfall as `Err` when fewer than `n` leaves remained.
    #[inline(always)]
    fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
        let advance = n.min(self.len);
        self.leaf_index += advance;
        self.len -= advance;
        NonZero::new(n - advance).map_or(Ok(()), Err)
    }

    /// Skips `n` leaves and builds the proof for the following one only.
    #[inline(always)]
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        match self.advance_by(n) {
            Ok(()) => self.next(),
            Err(_) => None,
        }
    }
}
464
465impl<'a, const N: usize> ExactSizeIterator for ProofsIterator<'a, N>
466where
467 [(); N.ilog2() as usize]:,
468 [(); N - 1]:,
469 [(); ensure_supported_n(N)]:,
470{
471 #[inline(always)]
472 fn len(&self) -> usize {
473 self.len
474 }
475}
476
// SAFETY: `size_hint` reports `self.len` as both bounds. `next` yields an item exactly when
// `self.len > 0` and decrements it each time. `advance_by`, `nth` and `last` adjust `self.len` by
// precisely the number of items they skip or consume, so the reported length is always exact.
unsafe impl<'a, const N: usize> TrustedLen for ProofsIterator<'a, N>
where
    [(); N.ilog2() as usize]:,
    [(); N - 1]:,
    [(); ensure_supported_n(N)]:,
{
}
484}