// ab_merkle_tree/unbalanced.rs
use crate::hash_pair;
use ab_blake3::OUT_LEN;
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use core::cmp::Ordering;
use core::mem::MaybeUninit;

/// Merkle tree over an arbitrary number of leaves, not limited to powers of two.
///
/// On each level, nodes are hashed pairwise from left to right; when a level has an odd number
/// of nodes, the last one is carried up to the next level unchanged. This type is stateless:
/// all functionality is provided by associated functions such as
/// [`UnbalancedMerkleTree::compute_root_only`] and [`UnbalancedMerkleTree::verify`].
#[derive(Debug)]
pub struct UnbalancedMerkleTree;

30impl UnbalancedMerkleTree {
36 #[inline]
43 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
44 pub fn compute_root_only<'a, const MAX_N: u64, Item, Iter>(
45 leaves: Iter,
46 ) -> Option<[u8; OUT_LEN]>
47 where
48 [(); MAX_N.next_power_of_two().ilog2() as usize + 1]:,
49 Item: Into<[u8; OUT_LEN]>,
50 Iter: IntoIterator<Item = Item> + 'a,
51 {
52 let mut stack = [[0u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1];
54 let mut num_leaves = 0_u64;
55
56 for hash in leaves {
57 if num_leaves >= MAX_N {
60 return None;
61 }
62
63 let mut current = hash.into();
64
65 let lowest_active_levels = num_leaves.trailing_ones() as usize;
67 for item in stack.iter().take(lowest_active_levels) {
68 current = hash_pair(item, ¤t);
69 }
70
71 stack[lowest_active_levels] = current;
73 num_leaves += 1;
74 }
75
76 if num_leaves == 0 {
77 return None;
79 }
80
81 let mut stack_bits = num_leaves;
82
83 {
84 let lowest_active_level = stack_bits.trailing_zeros() as usize;
85 stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
88 stack_bits &= !(1 << lowest_active_level);
90 }
91
92 loop {
94 let lowest_active_level = stack_bits.trailing_zeros() as usize;
95
96 if lowest_active_level == u64::BITS as usize {
97 break;
98 }
99
100 stack_bits &= !(1 << lowest_active_level);
102
103 let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };
105
106 stack[0] = hash_pair(lowest_active_level_item, &stack[0]);
107 }
108
109 Some(stack[0])
110 }
111
112 #[inline]
119 #[cfg(feature = "alloc")]
120 pub fn compute_root_and_proof<'a, const MAX_N: u64, Item, Iter>(
121 leaves: Iter,
122 target_index: usize,
123 ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)>
124 where
125 [(); MAX_N.next_power_of_two().ilog2() as usize + 1]:,
126 Item: Into<[u8; OUT_LEN]>,
127 Iter: IntoIterator<Item = Item> + 'a,
128 {
129 let mut stack = [[0u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1];
131 let mut proof = unsafe {
133 Box::<[MaybeUninit<[u8; OUT_LEN]>; MAX_N.next_power_of_two().ilog2() as usize]>::new_uninit().assume_init()
134 };
135
136 let (root, proof_length) =
137 Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, &mut proof)?;
138
139 let proof_capacity = proof.len();
140 let proof = Box::into_raw(proof);
141 let proof = unsafe {
144 Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
145 };
146
147 Some((root, proof))
148 }
149
150 #[inline]
157 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
158 pub fn compute_root_and_proof_in<'a, 'proof, const MAX_N: u64, Item, Iter>(
159 leaves: Iter,
160 target_index: usize,
161 proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.next_power_of_two().ilog2() as usize],
162 ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])>
163 where
164 [(); MAX_N.next_power_of_two().ilog2() as usize + 1]:,
165 Item: Into<[u8; OUT_LEN]>,
166 Iter: IntoIterator<Item = Item> + 'a,
167 {
168 let mut stack = [[0u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1];
170
171 let (root, proof_length) =
172 Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, proof)?;
173 let proof = unsafe { proof.get_unchecked_mut(..proof_length).assume_init_mut() };
175
176 Some((root, proof))
177 }
178
179 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
180 fn compute_root_and_proof_inner<'a, const MAX_N: u64, Item, Iter>(
181 leaves: Iter,
182 target_index: usize,
183 stack: &mut [[u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1],
184 proof: &mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.next_power_of_two().ilog2() as usize],
185 ) -> Option<([u8; OUT_LEN], usize)>
186 where
187 Item: Into<[u8; OUT_LEN]>,
188 Iter: IntoIterator<Item = Item> + 'a,
189 {
190 let mut proof_length = 0;
191 let mut num_leaves = 0_u64;
192
193 let mut current_target_level = None;
194 let mut position = target_index;
195
196 for (current_index, hash) in leaves.into_iter().enumerate() {
197 if num_leaves >= MAX_N {
200 return None;
201 }
202
203 let mut current = hash.into();
204
205 let lowest_active_levels = num_leaves.trailing_ones() as usize;
207
208 if current_index == target_index {
209 for item in stack.iter().take(lowest_active_levels) {
210 unsafe { proof.get_unchecked_mut(proof_length) }.write(*item);
213 proof_length += 1;
214
215 current = hash_pair(item, ¤t);
216
217 position /= 2;
219 }
220
221 current_target_level = Some(lowest_active_levels);
222 } else {
223 for (level, item) in stack.iter().enumerate().take(lowest_active_levels) {
224 if current_target_level == Some(level) {
225 unsafe { proof.get_unchecked_mut(proof_length) }.write(
227 if position.is_multiple_of(2) {
228 current
229 } else {
230 *item
231 },
232 );
233 proof_length += 1;
234
235 current_target_level = Some(level + 1);
236
237 position /= 2;
239 }
240
241 current = hash_pair(item, ¤t);
242 }
243 }
244
245 stack[lowest_active_levels] = current;
247 num_leaves += 1;
248 }
249
250 if target_index >= num_leaves as usize {
252 return None;
254 }
255
256 let Some(current_target_level) = current_target_level else {
257 return None;
259 };
260
261 let mut stack_bits = num_leaves;
262
263 {
264 let lowest_active_level = stack_bits.trailing_zeros() as usize;
265 stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
268 stack_bits &= !(1 << lowest_active_level);
270 }
271
272 let mut merged_peaks = false;
275 loop {
276 let lowest_active_level = stack_bits.trailing_zeros() as usize;
277
278 if lowest_active_level == u64::BITS as usize {
279 break;
280 }
281
282 stack_bits &= !(1 << lowest_active_level);
284
285 let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };
287
288 if lowest_active_level > current_target_level
289 || (lowest_active_level == current_target_level
290 && !position.is_multiple_of(2)
291 && !merged_peaks)
292 {
293 unsafe { proof.get_unchecked_mut(proof_length) }.write(*lowest_active_level_item);
295 proof_length += 1;
296 merged_peaks = false;
297 } else if lowest_active_level == current_target_level {
298 unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[0]);
300 proof_length += 1;
301 merged_peaks = false;
302 } else {
303 merged_peaks = true;
305 }
306
307 stack[0] = hash_pair(lowest_active_level_item, &stack[0]);
309
310 position /= 2;
311 }
312
313 Some((stack[0], proof_length))
314 }
315
316 #[inline]
318 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
319 pub fn verify(
320 root: &[u8; OUT_LEN],
321 proof: &[[u8; OUT_LEN]],
322 leaf_index: u64,
323 leaf: [u8; OUT_LEN],
324 num_leaves: u64,
325 ) -> bool {
326 if leaf_index >= num_leaves {
327 return false;
328 }
329
330 let mut current = leaf;
331 let mut position = leaf_index;
332 let mut proof_pos = 0;
333 let mut level_size = num_leaves;
334
335 while level_size > 1 {
337 let is_left = position.is_multiple_of(2);
338 let is_last = position == level_size - 1;
339
340 if is_left && !is_last {
341 if proof_pos >= proof.len() {
343 return false;
345 }
346 current = hash_pair(¤t, &proof[proof_pos]);
347 proof_pos += 1;
348 } else if !is_left {
349 if proof_pos >= proof.len() {
351 return false;
353 }
354 current = hash_pair(&proof[proof_pos], ¤t);
355 proof_pos += 1;
356 } else {
357 }
359
360 position /= 2;
361 level_size = level_size.div_ceil(2);
363 }
364
365 proof_pos == proof.len() && current == *root
367 }
368}