// ab_merkle_tree/unbalanced.rs
use crate::hash_pair;
use ab_blake3::OUT_LEN;
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use core::mem::MaybeUninit;

/// Zero-sized namespace type for Merkle Tree routines that accept an arbitrary number of leaves
/// (the leaf count does not have to be a power of two).
///
/// The type carries no state; all functionality is exposed as associated functions.
#[derive(Debug)]
pub struct UnbalancedMerkleTree;

30impl UnbalancedMerkleTree {
36 #[inline]
44 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
45 pub fn compute_root_only<'a, const MAX_N: u64, Item, Iter>(
46 leaves: Iter,
47 ) -> Option<[u8; OUT_LEN]>
48 where
49 [(); MAX_N.next_power_of_two().ilog2() as usize + 1]:,
50 Item: Into<[u8; OUT_LEN]>,
51 Iter: IntoIterator<Item = Item> + 'a,
52 {
53 let mut stack = [[0u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1];
55 let mut num_leaves = 0_u64;
56
57 for hash in leaves {
58 if num_leaves >= MAX_N {
61 return None;
62 }
63
64 let mut current = hash.into();
65
66 let lowest_active_levels = num_leaves.trailing_ones() as usize;
68 for item in stack.iter().take(lowest_active_levels) {
69 current = hash_pair(item, ¤t);
70 }
71
72 stack[lowest_active_levels] = current;
74 num_leaves += 1;
75 }
76
77 if num_leaves == 0 {
78 return None;
80 }
81
82 let mut stack_bits = num_leaves;
83
84 {
85 let lowest_active_level = stack_bits.trailing_zeros() as usize;
86 stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
89 stack_bits &= !(1 << lowest_active_level);
91 }
92
93 loop {
95 let lowest_active_level = stack_bits.trailing_zeros() as usize;
96
97 if lowest_active_level == u64::BITS as usize {
98 break;
99 }
100
101 stack_bits &= !(1 << lowest_active_level);
103
104 let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };
106
107 stack[0] = hash_pair(lowest_active_level_item, &stack[0]);
108 }
109
110 Some(stack[0])
111 }
112
113 #[inline]
120 #[cfg(feature = "alloc")]
121 pub fn compute_root_and_proof<'a, const MAX_N: u64, Item, Iter>(
122 leaves: Iter,
123 target_index: usize,
124 ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)>
125 where
126 [(); MAX_N.next_power_of_two().ilog2() as usize + 1]:,
127 Item: Into<[u8; OUT_LEN]>,
128 Iter: IntoIterator<Item = Item> + 'a,
129 {
130 let mut stack = [[0u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1];
132 let mut proof = unsafe {
134 Box::<[MaybeUninit<[u8; OUT_LEN]>; MAX_N.next_power_of_two().ilog2() as usize]>::new_uninit().assume_init()
135 };
136
137 let (root, proof_length) =
138 Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, &mut proof)?;
139
140 let proof_capacity = proof.len();
141 let proof = Box::into_raw(proof);
142 let proof = unsafe {
145 Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
146 };
147
148 Some((root, proof))
149 }
150
151 #[inline]
158 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
159 pub fn compute_root_and_proof_in<'a, 'proof, const MAX_N: u64, Item, Iter>(
160 leaves: Iter,
161 target_index: usize,
162 proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.next_power_of_two().ilog2() as usize],
163 ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])>
164 where
165 [(); MAX_N.next_power_of_two().ilog2() as usize + 1]:,
166 Item: Into<[u8; OUT_LEN]>,
167 Iter: IntoIterator<Item = Item> + 'a,
168 {
169 let mut stack = [[0u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1];
171
172 let (root, proof_length) =
173 Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, proof)?;
174 let proof = unsafe { proof.get_unchecked_mut(..proof_length).assume_init_mut() };
176
177 Some((root, proof))
178 }
179
180 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
181 fn compute_root_and_proof_inner<'a, const MAX_N: u64, Item, Iter>(
182 leaves: Iter,
183 target_index: usize,
184 stack: &mut [[u8; OUT_LEN]; MAX_N.next_power_of_two().ilog2() as usize + 1],
185 proof: &mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.next_power_of_two().ilog2() as usize],
186 ) -> Option<([u8; OUT_LEN], usize)>
187 where
188 Item: Into<[u8; OUT_LEN]>,
189 Iter: IntoIterator<Item = Item> + 'a,
190 {
191 let mut proof_length = 0;
192 let mut num_leaves = 0_u64;
193
194 let mut current_target_level = None;
195 let mut position = target_index;
196
197 for (current_index, hash) in leaves.into_iter().enumerate() {
198 if num_leaves >= MAX_N {
201 return None;
202 }
203
204 let mut current = hash.into();
205
206 let lowest_active_levels = num_leaves.trailing_ones() as usize;
208
209 if current_index == target_index {
210 for item in stack.iter().take(lowest_active_levels) {
211 unsafe { proof.get_unchecked_mut(proof_length) }.write(*item);
214 proof_length += 1;
215
216 current = hash_pair(item, ¤t);
217
218 position /= 2;
220 }
221
222 current_target_level = Some(lowest_active_levels);
223 } else {
224 for (level, item) in stack.iter().enumerate().take(lowest_active_levels) {
225 if current_target_level == Some(level) {
226 unsafe { proof.get_unchecked_mut(proof_length) }.write(
228 if position.is_multiple_of(2) {
229 current
230 } else {
231 *item
232 },
233 );
234 proof_length += 1;
235
236 current_target_level = Some(level + 1);
237
238 position /= 2;
240 }
241
242 current = hash_pair(item, ¤t);
243 }
244 }
245
246 stack[lowest_active_levels] = current;
248 num_leaves += 1;
249 }
250
251 if target_index >= num_leaves as usize {
253 return None;
255 }
256
257 let Some(current_target_level) = current_target_level else {
258 return None;
260 };
261
262 let mut stack_bits = num_leaves;
263
264 {
265 let lowest_active_level = stack_bits.trailing_zeros() as usize;
266 stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
269 stack_bits &= !(1 << lowest_active_level);
271 }
272
273 let mut merged_peaks = false;
276 loop {
277 let lowest_active_level = stack_bits.trailing_zeros() as usize;
278
279 if lowest_active_level == u64::BITS as usize {
280 break;
281 }
282
283 stack_bits &= !(1 << lowest_active_level);
285
286 let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };
288
289 if lowest_active_level > current_target_level
290 || (lowest_active_level == current_target_level
291 && !position.is_multiple_of(2)
292 && !merged_peaks)
293 {
294 unsafe { proof.get_unchecked_mut(proof_length) }.write(*lowest_active_level_item);
296 proof_length += 1;
297 merged_peaks = false;
298 } else if lowest_active_level == current_target_level {
299 unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[0]);
301 proof_length += 1;
302 merged_peaks = false;
303 } else {
304 merged_peaks = true;
306 }
307
308 stack[0] = hash_pair(lowest_active_level_item, &stack[0]);
310
311 position /= 2;
312 }
313
314 Some((stack[0], proof_length))
315 }
316
317 #[inline]
319 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
320 pub fn verify(
321 root: &[u8; OUT_LEN],
322 proof: &[[u8; OUT_LEN]],
323 leaf_index: u64,
324 leaf: [u8; OUT_LEN],
325 num_leaves: u64,
326 ) -> bool {
327 if leaf_index >= num_leaves {
328 return false;
329 }
330
331 let mut current = leaf;
332 let mut position = leaf_index;
333 let mut proof_pos = 0;
334 let mut level_size = num_leaves;
335
336 while level_size > 1 {
338 let is_left = position.is_multiple_of(2);
339 let is_last = position == level_size - 1;
340
341 if is_left && !is_last {
342 if proof_pos >= proof.len() {
344 return false;
346 }
347 current = hash_pair(¤t, &proof[proof_pos]);
348 proof_pos += 1;
349 } else if !is_left {
350 if proof_pos >= proof.len() {
352 return false;
354 }
355 current = hash_pair(&proof[proof_pos], ¤t);
356 proof_pos += 1;
357 } else {
358 }
360
361 position /= 2;
362 level_size = level_size.div_ceil(2);
364 }
365
366 proof_pos == proof.len() && current == *root
368 }
369}