ab_merkle_tree/unbalanced_hashed.rs
1use crate::hash_pair;
2#[cfg(feature = "alloc")]
3use alloc::boxed::Box;
4#[cfg(feature = "alloc")]
5use alloc::vec::Vec;
6use blake3::OUT_LEN;
7use core::mem::MaybeUninit;
8
/// Merkle Tree variant that has pre-hashed leaves with arbitrary number of elements.
///
/// This can be considered a general case of [`BalancedHashedMerkleTree`]. The root and proofs are
/// identical for both in case the number of leaves is a power of two. [`BalancedHashedMerkleTree`]
/// is more efficient and should be preferred when possible.
///
/// [`BalancedHashedMerkleTree`]: crate::balanced_hashed::BalancedHashedMerkleTree
///
/// The unbalanced tree is not padded, it is created the same way Merkle Mountain Range would be.
/// A node without a sibling on its level is promoted unchanged until it can be paired, as `L6` is
/// in the example with seven leaves below:
/// ```ignore
///               Root
///         /--------------\
///        H3              H4
///    /-------\         /----\
///   H0       H1       H2     \
///  /  \     /  \     /  \     \
/// L0  L1   L2  L3   L4  L5    L6
/// ```
#[derive(Debug)]
pub struct UnbalancedHashedMerkleTree;
29
// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
//  https://github.com/BLAKE3-team/BLAKE3/issues/478
// TODO: Experiment with replacing a single pass with splitting the whole data set into a sequence
//  of power-of-two-sized chunks that can be processed in parallel, and doing it recursively until
//  a single element is left. This can be done for both root creation and proof generation.
35impl UnbalancedHashedMerkleTree {
36 /// Compute Merkle Tree Root.
37 ///
38 /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
39 /// usage.
40 #[inline]
41 pub fn compute_root_only<'a, const N: usize, Iter>(leaves: Iter) -> Option<[u8; OUT_LEN]>
42 where
43 [(); N.ilog2() as usize + 1]:,
44 Iter: Iterator<Item = &'a [u8; OUT_LEN]> + 'a,
45 {
46 // Stack of intermediate nodes per tree level
47 let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
48 // Bitmask: bit `i = 1` if level `i` is active
49 let mut active_levels = 0_u64;
50
51 for &hash in leaves {
52 let mut current = hash;
53 let mut level = 0;
54
55 // Check if level is active by testing bit (active_levels & (1 << level))
56 while (active_levels & (1 << level)) != 0 {
57 current = hash_pair(&stack[level], ¤t);
58
59 // Clear the current level
60 active_levels &= !(1 << level);
61 level += 1;
62 }
63
64 // Place the current hash at the first inactive level
65 stack[level] = current;
66 // Set bit for level
67 active_levels |= 1 << level;
68 }
69
70 if active_levels == 0 {
71 // If no leaves were provided
72 return None;
73 }
74
75 {
76 let lowest_active_level = active_levels.trailing_zeros() as usize;
77 // Reuse `stack[0]` for resulting value
78 stack[0] = stack[lowest_active_level];
79 // Clear lowest active level
80 active_levels &= !(1 << lowest_active_level);
81 }
82
83 // Hash remaining peaks (if any) of the potentially unbalanced tree together
84 loop {
85 let lowest_active_level = active_levels.trailing_zeros() as usize;
86
87 if lowest_active_level == u64::BITS as usize {
88 break;
89 }
90
91 // Clear lowest active level
92 active_levels &= !(1 << lowest_active_level);
93
94 stack[0] = hash_pair(&stack[lowest_active_level], &stack[0]);
95 }
96
97 Some(stack[0])
98 }
99
100 /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
101 ///
102 /// Returns `Some(root, proof)` on success, `None` if index is outside of list of leaves.
103 ///
104 /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
105 /// usage.
106 #[inline]
107 #[cfg(feature = "alloc")]
108 pub fn compute_root_and_proof<'a, const N: usize, Iter>(
109 leaves: Iter,
110 target_index: usize,
111 ) -> Option<([u8; OUT_LEN], Vec<[u8; OUT_LEN]>)>
112 where
113 [(); N.ilog2() as usize + 1]:,
114 Iter: Iterator<Item = &'a [u8; OUT_LEN]> + 'a,
115 {
116 // Stack of intermediate nodes per tree level
117 let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
118 // SAFETY: Inner value is `MaybeUninit`
119 let mut proof = unsafe {
120 Box::<[MaybeUninit<[u8; OUT_LEN]>; N.ilog2() as usize + 1]>::new_uninit().assume_init()
121 };
122
123 let (root, proof_length) =
124 Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, &mut proof)?;
125
126 let proof_capacity = proof.len();
127 let proof = Box::into_raw(proof);
128 // SAFETY: Points to correctly allocated memory where `proof_length` elements were
129 // initialized
130 let proof = unsafe {
131 Vec::from_raw_parts(proof.cast::<[u8; OUT_LEN]>(), proof_length, proof_capacity)
132 };
133
134 Some((root, proof))
135 }
136
137 /// Compute Merkle Tree root and generate a proof for the `leaf` at `target_index`.
138 ///
139 /// Returns `Some(root, proof)` on success, `None` if index is outside of list of leaves.
140 ///
141 /// `MAX_N` generic constant defines the maximum number of elements supported and controls stack
142 /// usage.
143 #[inline]
144 pub fn compute_root_and_proof_in<'a, 'proof, const N: usize, Iter>(
145 leaves: Iter,
146 target_index: usize,
147 proof: &'proof mut [MaybeUninit<[u8; OUT_LEN]>; N.ilog2() as usize + 1],
148 ) -> Option<([u8; OUT_LEN], &'proof mut [[u8; OUT_LEN]])>
149 where
150 [(); N.ilog2() as usize + 1]:,
151 Iter: Iterator<Item = &'a [u8; OUT_LEN]> + 'a,
152 {
153 // Stack of intermediate nodes per tree level
154 let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
155
156 let (root, proof_length) =
157 Self::compute_root_and_proof_inner(leaves, target_index, &mut stack, proof)?;
158 // SAFETY: Just correctly initialized `proof_length` elements
159 let proof = unsafe { proof[..proof_length].assume_init_mut() };
160
161 Some((root, proof))
162 }
163
164 fn compute_root_and_proof_inner<'a, const N: usize, Iter>(
165 leaves: Iter,
166 target_index: usize,
167 stack: &mut [[u8; OUT_LEN]; N.ilog2() as usize + 1],
168 proof: &mut [MaybeUninit<[u8; OUT_LEN]>; N.ilog2() as usize + 1],
169 ) -> Option<([u8; OUT_LEN], usize)>
170 where
171 [(); N.ilog2() as usize + 1]:,
172 Iter: Iterator<Item = &'a [u8; OUT_LEN]> + 'a,
173 {
174 let mut proof_length = 0;
175 let mut active_levels = 0_u64;
176
177 let mut current_target_level = None;
178 let mut position = target_index;
179
180 for (current_index, &hash) in leaves.enumerate() {
181 let mut current = hash;
182 let mut level = 0;
183
184 if current_index == target_index {
185 // Check if level is active by testing bit (active_levels & (1 << level))
186 while (active_levels & (1 << level)) != 0 {
187 // If at the target leaf index, need to collect the proof
188 // SAFETY: Method signature guarantees upper bound of the proof length
189 unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[level]);
190 proof_length += 1;
191
192 current = hash_pair(&stack[level], ¤t);
193
194 // Clear the current level
195 active_levels &= !(1 << level);
196 level += 1;
197
198 // Move up the tree
199 position /= 2;
200 }
201
202 current_target_level = Some(level);
203
204 // Place the current hash at the first inactive level
205 stack[level] = current;
206 // Set bit for level
207 active_levels |= 1 << level;
208 } else {
209 // If at the target leaf index, need to collect the proof
210 while (active_levels & (1 << level)) != 0 {
211 if current_target_level == Some(level) {
212 // SAFETY: Method signature guarantees upper bound of the proof length
213 unsafe { proof.get_unchecked_mut(proof_length) }.write(
214 if position % 2 == 0 {
215 current
216 } else {
217 stack[level]
218 },
219 );
220 proof_length += 1;
221
222 current_target_level = Some(level + 1);
223
224 // Move up the tree
225 position /= 2;
226 }
227
228 current = hash_pair(&stack[level], ¤t);
229
230 // Clear the current level
231 active_levels &= !(1 << level);
232 level += 1;
233 }
234
235 // Place the current hash at the first inactive level
236 stack[level] = current;
237 // Set bit for level
238 active_levels |= 1 << level;
239 }
240 }
241
242 // `active_levels` here contains the number of leaves after above loop
243 if target_index >= active_levels as usize {
244 // If no leaves were provided
245 return None;
246 }
247
248 let Some(current_target_level) = current_target_level else {
249 // Index not found
250 return None;
251 };
252
253 {
254 let lowest_active_level = active_levels.trailing_zeros() as usize;
255 // Reuse `stack[0]` for resulting value
256 stack[0] = stack[lowest_active_level];
257 // Clear lowest active level
258 active_levels &= !(1 << lowest_active_level);
259 }
260
261 // Hash remaining peaks (if any) of the potentially unbalanced tree together and collect
262 // proof hashes
263 let mut merged_peaks = false;
264 loop {
265 let lowest_active_level = active_levels.trailing_zeros() as usize;
266
267 if lowest_active_level == u64::BITS as usize {
268 break;
269 }
270
271 // Clear lowest active level
272 active_levels &= !(1 << lowest_active_level);
273
274 if lowest_active_level > current_target_level
275 || (lowest_active_level == current_target_level
276 && (position % 2 != 0)
277 && !merged_peaks)
278 {
279 // SAFETY: Method signature guarantees upper bound of the proof length
280 unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[lowest_active_level]);
281 proof_length += 1;
282 merged_peaks = false;
283 } else if lowest_active_level == current_target_level {
284 // SAFETY: Method signature guarantees upper bound of the proof length
285 unsafe { proof.get_unchecked_mut(proof_length) }.write(stack[0]);
286 proof_length += 1;
287 merged_peaks = false;
288 } else {
289 // Not collecting proof because of the need to merge peaks of an unbalanced tree
290 merged_peaks = true;
291 }
292
293 // Collect the lowest peak into the proof
294 stack[0] = hash_pair(&stack[lowest_active_level], &stack[0]);
295
296 position /= 2;
297 }
298
299 Some((stack[0], proof_length))
300 }
301
302 /// Verify a Merkle proof for a leaf at the given index
303 #[inline]
304 // TODO: Make `num_leaves` optional in case the leaf is trusted (like just hashed from another
305 // value and guaranteed not to use the same keyed hash as used here)
306 pub fn verify(
307 root: &[u8; OUT_LEN],
308 proof: &[[u8; OUT_LEN]],
309 leaf_index: usize,
310 leaf: [u8; OUT_LEN],
311 num_leaves: usize,
312 ) -> bool {
313 if leaf_index >= num_leaves {
314 return false;
315 }
316
317 let mut current = leaf;
318 let mut position = leaf_index;
319 let mut proof_pos = 0;
320 let mut level_size = num_leaves;
321
322 // Rebuild the path to the root
323 while level_size > 1 {
324 let is_left = position % 2 == 0;
325 let is_last = position == level_size - 1;
326
327 if is_left && !is_last {
328 // Left node with a right sibling
329 if proof_pos >= proof.len() {
330 // Missing sibling
331 return false;
332 }
333 current = hash_pair(¤t, &proof[proof_pos]);
334 proof_pos += 1;
335 } else if !is_left {
336 // Right node with a left sibling
337 if proof_pos >= proof.len() {
338 // Missing sibling
339 return false;
340 }
341 current = hash_pair(&proof[proof_pos], ¤t);
342 proof_pos += 1;
343 } else {
344 // Last node, no sibling, keep current
345 }
346
347 position /= 2;
348 // Size of next level
349 level_size = level_size.div_ceil(2);
350 }
351
352 // Check if proof is fully used and matches root
353 proof_pos == proof.len() && current == *root
354 }
355}