// ab_merkle_tree/unbalanced.rs
use crate::hash_pair;
use ab_blake3::OUT_LEN;
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use core::mem::MaybeUninit;
8
/// Merkle tree over an arbitrary (not necessarily power-of-two) number of leaves.
///
/// At every level nodes are hashed pairwise from left to right with [`hash_pair`]; when a level
/// has an odd number of nodes, its last node is carried up to the next level unchanged. A tree
/// with a single leaf has that leaf as its root.
///
/// This is a zero-sized type that only groups the associated functions operating on such trees.
#[derive(Debug)]
pub struct UnbalancedMerkleTree;
29
30impl UnbalancedMerkleTree {
36 #[inline]
43 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
44 pub fn compute_root_only<'a, const MAX_N: u64, Item, Iter>(
45 leaves: Iter,
46 ) -> Option<[u8; OUT_LEN]>
47 where
48 [(); MAX_N.ilog2() as usize + 1]:,
49 Item: Into<[u8; OUT_LEN]>,
50 Iter: IntoIterator<Item = Item> + 'a,
51 {
52 let mut stack = [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1];
54 let mut num_leaves = 0_u64;
55
56 for hash in leaves {
57 if num_leaves >= MAX_N {
59 return None;
60 }
61
62 let mut current = hash.into();
63
64 let lowest_active_levels = num_leaves.trailing_ones() as usize;
66 for item in stack.iter().take(lowest_active_levels) {
67 current = hash_pair(item, ¤t);
68 }
69
70 stack[lowest_active_levels] = current;
72 num_leaves += 1;
73 }
74
75 if num_leaves == 0 {
76 return None;
78 }
79
80 let mut stack_bits = num_leaves;
81
82 {
83 let lowest_active_level = stack_bits.trailing_zeros() as usize;
84 stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
87 stack_bits &= !(1 << lowest_active_level);
89 }
90
91 loop {
93 let lowest_active_level = stack_bits.trailing_zeros() as usize;
94
95 if lowest_active_level == u64::BITS as usize {
96 break;
97 }
98
99 stack_bits &= !(1 << lowest_active_level);
101
102 let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };
104
105 stack[0] = hash_pair(lowest_active_level_item, &stack[0]);
106 }
107
108 Some(stack[0])
109 }
110
111 #[inline]
118 #[cfg(feature = "alloc")]
119 pub fn compute_root_and_proof<'a, const MAX_N: u64, Item, Iter>(
120 leaves: Iter,
121 target_index: usize,
122 ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)>
123 where
124 [(); MAX_N.ilog2() as usize + 1]:,
125 Item: Into<[u8; OUT_LEN]>,
126 Iter: IntoIterator<Item = Item> + 'a,
127 {
128 let mut stack = [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1];
130 let mut proof = unsafe {
132 Box::<[MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1]>::new_uninit()
133 .assume_init()
134 };
135
136 let (root, proof_length) =
137 Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, &mut proof)?;
138
139 let proof_capacity = proof.len();
140 let proof = Box::into_raw(proof);
141 let proof = unsafe {
144 Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
145 };
146
147 Some((root, proof))
148 }
149
150 #[inline]
157 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
158 pub fn compute_root_and_proof_in<'a, 'proof, const MAX_N: u64, Item, Iter>(
159 leaves: Iter,
160 target_index: usize,
161 proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
162 ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])>
163 where
164 [(); MAX_N.ilog2() as usize + 1]:,
165 Item: Into<[u8; OUT_LEN]>,
166 Iter: IntoIterator<Item = Item> + 'a,
167 {
168 let mut stack = [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1];
170
171 let (root, proof_length) =
172 Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, proof)?;
173 let proof = unsafe {
175 proof
176 .split_at_mut_unchecked(proof_length)
177 .0
178 .assume_init_mut()
179 };
180
181 Some((root, proof))
182 }
183
184 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
185 fn compute_root_and_proof_inner<'a, const MAX_N: u64, Item, Iter>(
186 leaves: Iter,
187 target_index: usize,
188 stack: &mut [[u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
189 proof: &mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
190 ) -> Option<([u8; OUT_LEN], usize)>
191 where
192 [(); MAX_N.ilog2() as usize + 1]:,
193 Item: Into<[u8; OUT_LEN]>,
194 Iter: IntoIterator<Item = Item> + 'a,
195 {
196 let mut proof_length = 0;
197 let mut num_leaves = 0_u64;
198
199 let mut current_target_level = None;
200 let mut position = target_index;
201
202 for (current_index, hash) in leaves.into_iter().enumerate() {
203 if num_leaves >= MAX_N {
205 return None;
206 }
207
208 let mut current = hash.into();
209
210 let lowest_active_levels = num_leaves.trailing_ones() as usize;
212
213 if current_index == target_index {
214 for item in stack.iter().take(lowest_active_levels) {
215 unsafe { proof.get_unchecked_mut(proof_length) }.write(*item);
218 proof_length += 1;
219
220 current = hash_pair(item, ¤t);
221
222 position /= 2;
224 }
225
226 current_target_level = Some(lowest_active_levels);
227 } else {
228 for (level, item) in stack.iter().enumerate().take(lowest_active_levels) {
229 if current_target_level == Some(level) {
230 unsafe { proof.get_unchecked_mut(proof_length) }
232 .write(if position % 2 == 0 { current } else { *item });
233 proof_length += 1;
234
235 current_target_level = Some(level + 1);
236
237 position /= 2;
239 }
240
241 current = hash_pair(item, ¤t);
242 }
243 }
244
245 stack[lowest_active_levels] = current;
247 num_leaves += 1;
248 }
249
250 if target_index >= num_leaves as usize {
252 return None;
254 }
255
256 let Some(current_target_level) = current_target_level else {
257 return None;
259 };
260
261 let mut stack_bits = num_leaves;
262
263 {
264 let lowest_active_level = stack_bits.trailing_zeros() as usize;
265 stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
268 stack_bits &= !(1 << lowest_active_level);
270 }
271
272 let mut merged_peaks = false;
275 loop {
276 let lowest_active_level = stack_bits.trailing_zeros() as usize;
277
278 if lowest_active_level == u64::BITS as usize {
279 break;
280 }
281
282 stack_bits &= !(1 << lowest_active_level);
284
285 let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };
287
288 if lowest_active_level > current_target_level
289 || (lowest_active_level == current_target_level
290 && (position % 2 != 0)
291 && !merged_peaks)
292 {
293 unsafe { proof.get_unchecked_mut(proof_length) }.write(*lowest_active_level_item);
295 proof_length += 1;
296 merged_peaks = false;
297 } else if lowest_active_level == current_target_level {
298 unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[0]);
300 proof_length += 1;
301 merged_peaks = false;
302 } else {
303 merged_peaks = true;
305 }
306
307 stack[0] = hash_pair(lowest_active_level_item, &stack[0]);
309
310 position /= 2;
311 }
312
313 Some((stack[0], proof_length))
314 }
315
316 #[inline]
318 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
319 pub fn verify(
320 root: &[u8; OUT_LEN],
321 proof: &[[u8; OUT_LEN]],
322 leaf_index: u64,
323 leaf: [u8; OUT_LEN],
324 num_leaves: u64,
325 ) -> bool {
326 if leaf_index >= num_leaves {
327 return false;
328 }
329
330 let mut current = leaf;
331 let mut position = leaf_index;
332 let mut proof_pos = 0;
333 let mut level_size = num_leaves;
334
335 while level_size > 1 {
337 let is_left = position % 2 == 0;
338 let is_last = position == level_size - 1;
339
340 if is_left && !is_last {
341 if proof_pos >= proof.len() {
343 return false;
345 }
346 current = hash_pair(¤t, &proof[proof_pos]);
347 proof_pos += 1;
348 } else if !is_left {
349 if proof_pos >= proof.len() {
351 return false;
353 }
354 current = hash_pair(&proof[proof_pos], ¤t);
355 proof_pos += 1;
356 } else {
357 }
359
360 position /= 2;
361 level_size = level_size.div_ceil(2);
363 }
364
365 proof_pos == proof.len() && current == *root
367 }
368}