ab_merkle_tree/unbalanced.rs
1use crate::hash_pair;
2use ab_blake3::OUT_LEN;
3#[cfg(feature = "alloc")]
4use alloc::boxed::Box;
5#[cfg(feature = "alloc")]
6use alloc::vec::Vec;
7use core::mem::MaybeUninit;
8
/// Merkle Tree variant that has pre-hashed leaves with arbitrary number of elements.
///
/// This can be considered a general case of [`BalancedMerkleTree`]. The root and proofs are
/// identical for both in case the number of leaves is a power of two. [`BalancedMerkleTree`] is
/// more efficient and should be preferred when possible.
///
/// [`BalancedMerkleTree`]: crate::balanced::BalancedMerkleTree
///
/// The unbalanced tree is not padded, it is created the same way Merkle Mountain Range would be:
/// ```text
///               Root
///         /--------------\
///        H3              H4
///    /-------\         /----\
///   H0       H1       H2     \
///  /  \     /  \     /  \     \
/// L0  L1   L2  L3   L4  L5    L6
/// ```
///
/// The type itself carries no state; all functionality is exposed as associated functions.
#[derive(Debug)]
pub struct UnbalancedMerkleTree;
29
// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
//  https://github.com/BLAKE3-team/BLAKE3/issues/478
// TODO: Experiment with replacing a single pass with splitting the whole data set into a sequence
//  of power-of-two sized chunks that can be processed in parallel and do it recursively until a
//  single element is left. This can be done for both root creation and proof generation.
35impl UnbalancedMerkleTree {
36 /// Compute Merkle Tree Root.
37 ///
38 /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
39 /// usage.
40 ///
41 /// Returns `None` for an empty list of leaves.
42 #[inline]
43 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
44 pub fn compute_root_only<'a, const MAX_N: u64, Item, Iter>(
45 leaves: Iter,
46 ) -> Option<[u8; OUT_LEN]>
47 where
48 [(); MAX_N.ilog2() as usize + 1]:,
49 Item: Into<[u8; OUT_LEN]>,
50 Iter: IntoIterator<Item = Item> + 'a,
51 {
52 // Stack of intermediate nodes per tree level
53 let mut stack = [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1];
54 let mut num_leaves = 0_u64;
55
56 for hash in leaves {
57 // How many leaves were processed so far. Should have been `num_leaves == MAX_N`, but
58 // `>=` helps compiler with panic safety checks.
59 if num_leaves >= MAX_N {
60 return None;
61 }
62
63 let mut current = hash.into();
64
65 // Every bit set to `1` corresponds to an active Merkle Tree level
66 let lowest_active_levels = num_leaves.trailing_ones() as usize;
67 for item in stack.iter().take(lowest_active_levels) {
68 current = hash_pair(item, ¤t);
69 }
70
71 // Place the current hash at the first inactive level
72 stack[lowest_active_levels] = current;
73 num_leaves += 1;
74 }
75
76 if num_leaves == 0 {
77 // If no leaves were provided
78 return None;
79 }
80
81 let mut stack_bits = num_leaves;
82
83 {
84 let lowest_active_level = stack_bits.trailing_zeros() as usize;
85 // Reuse `stack[0]` for resulting value
86 // SAFETY: Active level must have been set successfully before, hence it exists
87 stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
88 // Clear lowest active level
89 stack_bits &= !(1 << lowest_active_level);
90 }
91
92 // Hash remaining peaks (if any) of the potentially unbalanced tree together
93 loop {
94 let lowest_active_level = stack_bits.trailing_zeros() as usize;
95
96 if lowest_active_level == u64::BITS as usize {
97 break;
98 }
99
100 // Clear lowest active level for next iteration
101 stack_bits &= !(1 << lowest_active_level);
102
103 // SAFETY: Active level must have been set successfully before, hence it exists
104 let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };
105
106 stack[0] = hash_pair(lowest_active_level_item, &stack[0]);
107 }
108
109 Some(stack[0])
110 }
111
112 /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
113 ///
114 /// Returns `Some(root, proof)` on success, `None` if index is outside of list of leaves.
115 ///
116 /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
117 /// usage.
118 #[inline]
119 #[cfg(feature = "alloc")]
120 pub fn compute_root_and_proof<'a, const MAX_N: u64, Item, Iter>(
121 leaves: Iter,
122 target_index: usize,
123 ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)>
124 where
125 [(); MAX_N.ilog2() as usize + 1]:,
126 Item: Into<[u8; OUT_LEN]>,
127 Iter: IntoIterator<Item = Item> + 'a,
128 {
129 // Stack of intermediate nodes per tree level
130 let mut stack = [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1];
131 // SAFETY: Inner value is `MaybeUninit`
132 let mut proof = unsafe {
133 Box::<[MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1]>::new_uninit()
134 .assume_init()
135 };
136
137 let (root, proof_length) =
138 Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, &mut proof)?;
139
140 let proof_capacity = proof.len();
141 let proof = Box::into_raw(proof);
142 // SAFETY: Points to correctly allocated memory where `proof_length` elements were
143 // initialized
144 let proof = unsafe {
145 Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
146 };
147
148 Some((root, proof))
149 }
150
151 /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
152 ///
153 /// Returns `Some(root, proof)` on success, `None` if index is outside of list of leaves.
154 ///
155 /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
156 /// usage.
157 #[inline]
158 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
159 pub fn compute_root_and_proof_in<'a, 'proof, const MAX_N: u64, Item, Iter>(
160 leaves: Iter,
161 target_index: usize,
162 proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
163 ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])>
164 where
165 [(); MAX_N.ilog2() as usize + 1]:,
166 Item: Into<[u8; OUT_LEN]>,
167 Iter: IntoIterator<Item = Item> + 'a,
168 {
169 // Stack of intermediate nodes per tree level
170 let mut stack = [[0u8; OUT_LEN]; MAX_N.ilog2() as usize + 1];
171
172 let (root, proof_length) =
173 Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, proof)?;
174 // SAFETY: Just correctly initialized `proof_length` elements
175 let proof = unsafe {
176 proof
177 .split_at_mut_unchecked(proof_length)
178 .0
179 .assume_init_mut()
180 };
181
182 Some((root, proof))
183 }
184
185 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
186 fn compute_root_and_proof_inner<'a, const MAX_N: u64, Item, Iter>(
187 leaves: Iter,
188 target_index: usize,
189 stack: &mut [[u8; OUT_LEN]; MAX_N.ilog2() as usize + 1],
190 proof: &mut [MaybeUninit<[u8; OUT_LEN]>; MAX_N.ilog2() as usize + 1],
191 ) -> Option<([u8; OUT_LEN], usize)>
192 where
193 [(); MAX_N.ilog2() as usize + 1]:,
194 Item: Into<[u8; OUT_LEN]>,
195 Iter: IntoIterator<Item = Item> + 'a,
196 {
197 let mut proof_length = 0;
198 let mut num_leaves = 0_u64;
199
200 let mut current_target_level = None;
201 let mut position = target_index;
202
203 for (current_index, hash) in leaves.into_iter().enumerate() {
204 // How many leaves were processed so far. Should have been `num_leaves == MAX_N`, but
205 // `>=` helps compiler with panic safety checks.
206 if num_leaves >= MAX_N {
207 return None;
208 }
209
210 let mut current = hash.into();
211
212 // Every bit set to `1` corresponds to an active Merkle Tree level
213 let lowest_active_levels = num_leaves.trailing_ones() as usize;
214
215 if current_index == target_index {
216 for item in stack.iter().take(lowest_active_levels) {
217 // If at the target leaf index, need to collect the proof
218 // SAFETY: Method signature guarantees upper bound of the proof length
219 unsafe { proof.get_unchecked_mut(proof_length) }.write(*item);
220 proof_length += 1;
221
222 current = hash_pair(item, ¤t);
223
224 // Move up the tree
225 position /= 2;
226 }
227
228 current_target_level = Some(lowest_active_levels);
229 } else {
230 for (level, item) in stack.iter().enumerate().take(lowest_active_levels) {
231 if current_target_level == Some(level) {
232 // SAFETY: Method signature guarantees upper bound of the proof length
233 unsafe { proof.get_unchecked_mut(proof_length) }.write(
234 if position.is_multiple_of(2) {
235 current
236 } else {
237 *item
238 },
239 );
240 proof_length += 1;
241
242 current_target_level = Some(level + 1);
243
244 // Move up the tree
245 position /= 2;
246 }
247
248 current = hash_pair(item, ¤t);
249 }
250 }
251
252 // Place the current hash at the first inactive level
253 stack[lowest_active_levels] = current;
254 num_leaves += 1;
255 }
256
257 // `active_levels` here contains the number of leaves after above loop
258 if target_index >= num_leaves as usize {
259 // If no leaves were provided
260 return None;
261 }
262
263 let Some(current_target_level) = current_target_level else {
264 // Index not found
265 return None;
266 };
267
268 let mut stack_bits = num_leaves;
269
270 {
271 let lowest_active_level = stack_bits.trailing_zeros() as usize;
272 // Reuse `stack[0]` for resulting value
273 // SAFETY: Active level must have been set successfully before, hence it exists
274 stack[0] = *unsafe { stack.get_unchecked(lowest_active_level) };
275 // Clear lowest active level
276 stack_bits &= !(1 << lowest_active_level);
277 }
278
279 // Hash remaining peaks (if any) of the potentially unbalanced tree together and collect
280 // proof hashes
281 let mut merged_peaks = false;
282 loop {
283 let lowest_active_level = stack_bits.trailing_zeros() as usize;
284
285 if lowest_active_level == u64::BITS as usize {
286 break;
287 }
288
289 // Clear lowest active level for next iteration
290 stack_bits &= !(1 << lowest_active_level);
291
292 // SAFETY: Active level must have been set successfully before, hence it exists
293 let lowest_active_level_item = unsafe { stack.get_unchecked(lowest_active_level) };
294
295 if lowest_active_level > current_target_level
296 || (lowest_active_level == current_target_level
297 && !position.is_multiple_of(2)
298 && !merged_peaks)
299 {
300 // SAFETY: Method signature guarantees upper bound of the proof length
301 unsafe { proof.get_unchecked_mut(proof_length) }.write(*lowest_active_level_item);
302 proof_length += 1;
303 merged_peaks = false;
304 } else if lowest_active_level == current_target_level {
305 // SAFETY: Method signature guarantees upper bound of the proof length
306 unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[0]);
307 proof_length += 1;
308 merged_peaks = false;
309 } else {
310 // Not collecting proof because of the need to merge peaks of an unbalanced tree
311 merged_peaks = true;
312 }
313
314 // Collect the lowest peak into the proof
315 stack[0] = hash_pair(lowest_active_level_item, &stack[0]);
316
317 position /= 2;
318 }
319
320 Some((stack[0], proof_length))
321 }
322
323 /// Verify a Merkle proof for a leaf at the given index
324 #[inline]
325 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
326 pub fn verify(
327 root: &[u8; OUT_LEN],
328 proof: &[[u8; OUT_LEN]],
329 leaf_index: u64,
330 leaf: [u8; OUT_LEN],
331 num_leaves: u64,
332 ) -> bool {
333 if leaf_index >= num_leaves {
334 return false;
335 }
336
337 let mut current = leaf;
338 let mut position = leaf_index;
339 let mut proof_pos = 0;
340 let mut level_size = num_leaves;
341
342 // Rebuild the path to the root
343 while level_size > 1 {
344 let is_left = position.is_multiple_of(2);
345 let is_last = position == level_size - 1;
346
347 if is_left && !is_last {
348 // Left node with a right sibling
349 if proof_pos >= proof.len() {
350 // Missing sibling
351 return false;
352 }
353 current = hash_pair(¤t, &proof[proof_pos]);
354 proof_pos += 1;
355 } else if !is_left {
356 // Right node with a left sibling
357 if proof_pos >= proof.len() {
358 // Missing sibling
359 return false;
360 }
361 current = hash_pair(&proof[proof_pos], ¤t);
362 proof_pos += 1;
363 } else {
364 // Last node, no sibling, keep current
365 }
366
367 position /= 2;
368 // Size of next level
369 level_size = level_size.div_ceil(2);
370 }
371
372 // Check if proof is fully used and matches root
373 proof_pos == proof.len() && current == *root
374 }
375}