ab_merkle_tree/unbalanced_hashed.rs
1use crate::hash_pair;
2#[cfg(feature = "alloc")]
3use alloc::boxed::Box;
4#[cfg(feature = "alloc")]
5use alloc::vec::Vec;
6use blake3::OUT_LEN;
7use core::mem::MaybeUninit;
8
/// Merkle Tree variant that has pre-hashed leaves with arbitrary number of elements.
///
/// This can be considered a general case of [`BalancedHashedMerkleTree`]. The root and proofs are
/// identical for both in case the number of leaves is a power of two. [`BalancedHashedMerkleTree`]
/// is more efficient and should be preferred when possible.
///
/// [`BalancedHashedMerkleTree`]: crate::balanced_hashed::BalancedHashedMerkleTree
///
/// The unbalanced tree is not padded, it is created the same way Merkle Mountain Range would be:
/// full power-of-two subtrees ("peaks") are built from left to right, then the peaks are combined
/// starting from the rightmost (smallest) one. A node that has no sibling on its level is promoted
/// to the next level unchanged, as `L6` is below:
/// ```text
///                 Root
///          /----------------\
///         H3                H4
///     /-------\         /-------\
///    H0       H1       H2       |
///   /  \     /  \     /  \      |
///  L0  L1   L2  L3   L4  L5     L6
/// ```
#[derive(Debug)]
pub struct UnbalancedHashedMerkleTree;
29
30// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
31// https://github.com/BLAKE3-team/BLAKE3/issues/478
32// TODO: Experiment with replacing a single pass with splitting the whole data set with a sequence
33// of power-of-two elements that can be processed in parallel and do it recursively until a single
34// element is left. This can be done for both root creation and proof generation.
35impl UnbalancedHashedMerkleTree {
36 /// Compute Merkle Tree Root.
37 ///
38 /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
39 /// usage.
40 ///
41 /// Returns `None` for an empty list of leaves.
42 #[inline]
43 pub fn compute_root_only<'a, const N: usize, Item, Iter>(leaves: Iter) -> Option<[u8; OUT_LEN]>
44 where
45 [(); N.ilog2() as usize + 1]:,
46 Item: Into<[u8; OUT_LEN]>,
47 Iter: IntoIterator<Item = Item> + 'a,
48 {
49 // Stack of intermediate nodes per tree level
50 let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
51 // Bitmask: bit `i = 1` if level `i` is active
52 let mut active_levels = 0_u64;
53
54 for hash in leaves {
55 let mut current = hash.into();
56 let mut level = 0;
57
58 // Check if level is active by testing bit (active_levels & (1 << level))
59 while (active_levels & (1 << level)) != 0 {
60 current = hash_pair(&stack[level], ¤t);
61
62 // Clear the current level
63 active_levels &= !(1 << level);
64 level += 1;
65 }
66
67 // Place the current hash at the first inactive level
68 stack[level] = current;
69 // Set bit for level
70 active_levels |= 1 << level;
71 }
72
73 if active_levels == 0 {
74 // If no leaves were provided
75 return None;
76 }
77
78 {
79 let lowest_active_level = active_levels.trailing_zeros() as usize;
80 // Reuse `stack[0]` for resulting value
81 stack[0] = stack[lowest_active_level];
82 // Clear lowest active level
83 active_levels &= !(1 << lowest_active_level);
84 }
85
86 // Hash remaining peaks (if any) of the potentially unbalanced tree together
87 loop {
88 let lowest_active_level = active_levels.trailing_zeros() as usize;
89
90 if lowest_active_level == u64::BITS as usize {
91 break;
92 }
93
94 // Clear lowest active level
95 active_levels &= !(1 << lowest_active_level);
96
97 stack[0] = hash_pair(&stack[lowest_active_level], &stack[0]);
98 }
99
100 Some(stack[0])
101 }
102
103 /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
104 ///
105 /// Returns `Some(root, proof)` on success, `None` if index is outside of list of leaves.
106 ///
107 /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
108 /// usage.
109 #[inline]
110 #[cfg(feature = "alloc")]
111 pub fn compute_root_and_proof<'a, const N: usize, Item, Iter>(
112 leaves: Iter,
113 target_index: usize,
114 ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)>
115 where
116 [(); N.ilog2() as usize + 1]:,
117 Item: Into<[u8; OUT_LEN]>,
118 Iter: IntoIterator<Item = Item> + 'a,
119 {
120 // Stack of intermediate nodes per tree level
121 let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
122 // SAFETY: Inner value is `MaybeUninit`
123 let mut proof = unsafe {
124 Box::<[MaybeUninit<[u8; OUT_LEN]>; N.ilog2() as usize + 1]>::new_uninit().assume_init()
125 };
126
127 let (root, proof_length) =
128 Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, &mut proof)?;
129
130 let proof_capacity = proof.len();
131 let proof = Box::into_raw(proof);
132 // SAFETY: Points to correctly allocated memory where `proof_length` elements were
133 // initialized
134 let proof = unsafe {
135 Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
136 };
137
138 Some((root, proof))
139 }
140
141 /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
142 ///
143 /// Returns `Some(root, proof)` on success, `None` if index is outside of list of leaves.
144 ///
145 /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
146 /// usage.
147 #[inline]
148 pub fn compute_root_and_proof_in<'a, 'proof, const N: usize, Item, Iter>(
149 leaves: Iter,
150 target_index: usize,
151 proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; N.ilog2() as usize + 1],
152 ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])>
153 where
154 [(); N.ilog2() as usize + 1]:,
155 Item: Into<[u8; OUT_LEN]>,
156 Iter: IntoIterator<Item = Item> + 'a,
157 {
158 // Stack of intermediate nodes per tree level
159 let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
160
161 let (root, proof_length) =
162 Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, proof)?;
163 // SAFETY: Just correctly initialized `proof_length` elements
164 let proof = unsafe { proof[..proof_length].assume_init_mut() };
165
166 Some((root, proof))
167 }
168
169 fn compute_root_and_proof_inner<'a, const N: usize, Item, Iter>(
170 leaves: Iter,
171 target_index: usize,
172 stack: &mut [[u8; OUT_LEN]; N.ilog2() as usize + 1],
173 proof: &mut [MaybeUninit<[u8; OUT_LEN]>; N.ilog2() as usize + 1],
174 ) -> Option<([u8; OUT_LEN], usize)>
175 where
176 [(); N.ilog2() as usize + 1]:,
177 Item: Into<[u8; OUT_LEN]>,
178 Iter: IntoIterator<Item = Item> + 'a,
179 {
180 let mut proof_length = 0;
181 let mut active_levels = 0_u64;
182
183 let mut current_target_level = None;
184 let mut position = target_index;
185
186 for (current_index, hash) in leaves.into_iter().enumerate() {
187 let mut current = hash.into();
188 let mut level = 0;
189
190 if current_index == target_index {
191 // Check if level is active by testing bit (active_levels & (1 << level))
192 while (active_levels & (1 << level)) != 0 {
193 // If at the target leaf index, need to collect the proof
194 // SAFETY: Method signature guarantees upper bound of the proof length
195 unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[level]);
196 proof_length += 1;
197
198 current = hash_pair(&stack[level], ¤t);
199
200 // Clear the current level
201 active_levels &= !(1 << level);
202 level += 1;
203
204 // Move up the tree
205 position /= 2;
206 }
207
208 current_target_level = Some(level);
209
210 // Place the current hash at the first inactive level
211 stack[level] = current;
212 // Set bit for level
213 active_levels |= 1 << level;
214 } else {
215 // If at the target leaf index, need to collect the proof
216 while (active_levels & (1 << level)) != 0 {
217 if current_target_level == Some(level) {
218 // SAFETY: Method signature guarantees upper bound of the proof length
219 unsafe { proof.get_unchecked_mut(proof_length) }.write(
220 if position % 2 == 0 {
221 current
222 } else {
223 stack[level]
224 },
225 );
226 proof_length += 1;
227
228 current_target_level = Some(level + 1);
229
230 // Move up the tree
231 position /= 2;
232 }
233
234 current = hash_pair(&stack[level], ¤t);
235
236 // Clear the current level
237 active_levels &= !(1 << level);
238 level += 1;
239 }
240
241 // Place the current hash at the first inactive level
242 stack[level] = current;
243 // Set bit for level
244 active_levels |= 1 << level;
245 }
246 }
247
248 // `active_levels` here contains the number of leaves after above loop
249 if target_index >= active_levels as usize {
250 // If no leaves were provided
251 return None;
252 }
253
254 let Some(current_target_level) = current_target_level else {
255 // Index not found
256 return None;
257 };
258
259 {
260 let lowest_active_level = active_levels.trailing_zeros() as usize;
261 // Reuse `stack[0]` for resulting value
262 stack[0] = stack[lowest_active_level];
263 // Clear lowest active level
264 active_levels &= !(1 << lowest_active_level);
265 }
266
267 // Hash remaining peaks (if any) of the potentially unbalanced tree together and collect
268 // proof hashes
269 let mut merged_peaks = false;
270 loop {
271 let lowest_active_level = active_levels.trailing_zeros() as usize;
272
273 if lowest_active_level == u64::BITS as usize {
274 break;
275 }
276
277 // Clear lowest active level
278 active_levels &= !(1 << lowest_active_level);
279
280 if lowest_active_level > current_target_level
281 || (lowest_active_level == current_target_level
282 && (position % 2 != 0)
283 && !merged_peaks)
284 {
285 // SAFETY: Method signature guarantees upper bound of the proof length
286 unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[lowest_active_level]);
287 proof_length += 1;
288 merged_peaks = false;
289 } else if lowest_active_level == current_target_level {
290 // SAFETY: Method signature guarantees upper bound of the proof length
291 unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[0]);
292 proof_length += 1;
293 merged_peaks = false;
294 } else {
295 // Not collecting proof because of the need to merge peaks of an unbalanced tree
296 merged_peaks = true;
297 }
298
299 // Collect the lowest peak into the proof
300 stack[0] = hash_pair(&stack[lowest_active_level], &stack[0]);
301
302 position /= 2;
303 }
304
305 Some((stack[0], proof_length))
306 }
307
308 /// Verify a Merkle proof for a leaf at the given index
309 #[inline]
310 pub fn verify(
311 root: &[u8; OUT_LEN],
312 proof: &[[u8; OUT_LEN]],
313 leaf_index: usize,
314 leaf: [u8; OUT_LEN],
315 num_leaves: usize,
316 ) -> bool {
317 if leaf_index >= num_leaves {
318 return false;
319 }
320
321 let mut current = leaf;
322 let mut position = leaf_index;
323 let mut proof_pos = 0;
324 let mut level_size = num_leaves;
325
326 // Rebuild the path to the root
327 while level_size > 1 {
328 let is_left = position % 2 == 0;
329 let is_last = position == level_size - 1;
330
331 if is_left && !is_last {
332 // Left node with a right sibling
333 if proof_pos >= proof.len() {
334 // Missing sibling
335 return false;
336 }
337 current = hash_pair(¤t, &proof[proof_pos]);
338 proof_pos += 1;
339 } else if !is_left {
340 // Right node with a left sibling
341 if proof_pos >= proof.len() {
342 // Missing sibling
343 return false;
344 }
345 current = hash_pair(&proof[proof_pos], ¤t);
346 proof_pos += 1;
347 } else {
348 // Last node, no sibling, keep current
349 }
350
351 position /= 2;
352 // Size of next level
353 level_size = level_size.div_ceil(2);
354 }
355
356 // Check if proof is fully used and matches root
357 proof_pos == proof.len() && current == *root
358 }
359}