ab_merkle_tree/
balanced_hashed.rs1use crate::hash_pair;
2#[cfg(feature = "alloc")]
3use alloc::boxed::Box;
4use blake3::OUT_LEN;
5use core::iter;
6use core::iter::TrustedLen;
7use core::mem::MaybeUninit;
8
9#[derive(Debug)]
27pub struct BalancedHashedMerkleTree<'a, const N: usize>
28where
29 [(); N - 1]:,
30{
31 leaves: &'a [[u8; OUT_LEN]],
32 tree: [[u8; OUT_LEN]; N - 1],
34}
35
36impl<'a, const N: usize> BalancedHashedMerkleTree<'a, N>
39where
40 [(); N - 1]:,
41{
42 pub fn new(leaves: &'a [[u8; OUT_LEN]; N]) -> Self {
47 let mut tree = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); _];
48
49 Self::init_internal(leaves, &mut tree);
50
51 Self {
52 leaves,
53 tree: unsafe { tree.transpose().assume_init() },
55 }
56 }
57
58 pub fn new_in<'b>(
60 instance: &'b mut MaybeUninit<Self>,
61 leaves: &'a [[u8; OUT_LEN]; N],
62 ) -> &'b mut Self {
63 let instance_ptr = instance.as_mut_ptr();
64 unsafe {
66 (&raw mut (*instance_ptr).leaves).write(leaves);
67 }
68 let tree = {
69 let tree_ptr = unsafe { &raw mut (*instance_ptr).tree };
71 unsafe {
73 tree_ptr
74 .cast::<[MaybeUninit<[u8; OUT_LEN]>; N - 1]>()
75 .as_mut_unchecked()
76 }
77 };
78
79 Self::init_internal(leaves, tree);
80
81 unsafe { instance.assume_init_mut() }
83 }
84
85 #[cfg(feature = "alloc")]
88 pub fn new_boxed(leaves: &'a [[u8; OUT_LEN]; N]) -> Box<Self> {
89 let mut instance = Box::<Self>::new_uninit();
90
91 Self::new_in(&mut instance, leaves);
92
93 unsafe { instance.assume_init() }
95 }
96
97 fn init_internal(leaves: &[[u8; OUT_LEN]; N], tree: &mut [MaybeUninit<[u8; OUT_LEN]>; N - 1]) {
98 let mut tree_hashes = tree.as_mut_slice();
99 let mut level_hashes = leaves.as_slice();
100
101 let mut pair = [0u8; OUT_LEN * 2];
102 while level_hashes.len() > 1 {
103 let num_pairs = level_hashes.len() / 2;
104 let parent_hashes;
105 (parent_hashes, tree_hashes) = unsafe { tree_hashes.split_at_mut_unchecked(num_pairs) };
108
109 for pair_index in 0..num_pairs {
110 let left_hash = unsafe { level_hashes.get_unchecked(pair_index * 2) };
112 let right_hash = unsafe { level_hashes.get_unchecked(pair_index * 2 + 1) };
114 let parent_hash = unsafe { parent_hashes.get_unchecked_mut(pair_index) };
116
117 pair[..OUT_LEN].copy_from_slice(left_hash);
118 pair[OUT_LEN..].copy_from_slice(right_hash);
119
120 parent_hash.write(hash_pair(left_hash, right_hash));
121 }
122
123 level_hashes = unsafe { parent_hashes.assume_init_ref() };
125 }
126 }
127
128 #[inline]
133 pub fn compute_root_only(leaves: &[[u8; OUT_LEN]; N]) -> [u8; OUT_LEN]
134 where
135 [(); N.ilog2() as usize + 1]:,
136 {
137 if leaves.len() == 1 {
138 return leaves[0];
139 }
140
141 let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
143 let mut active_levels = 0_u32;
145
146 for &hash in leaves {
147 let mut current = hash;
148 let mut level = 0;
149
150 while active_levels & (1 << level) != 0 {
152 current = hash_pair(&stack[level], ¤t);
153
154 active_levels &= !(1 << level);
156 level += 1;
157 }
158
159 stack[level] = current;
161 active_levels |= 1 << level;
163 }
164
165 stack[N.ilog2() as usize]
166 }
167
168 #[inline]
172 pub fn root(&self) -> [u8; OUT_LEN] {
173 *self
174 .tree
175 .last()
176 .or(self.leaves.last())
177 .expect("There is always at least one leaf hash; qed")
178 }
179
180 pub fn all_proofs(
182 &self,
183 ) -> impl ExactSizeIterator<Item = [[u8; OUT_LEN]; N.ilog2() as usize]> + TrustedLen
184 where
185 [(); N.ilog2() as usize]:,
186 {
187 let iter = self
188 .leaves
189 .array_chunks()
190 .enumerate()
191 .flat_map(|(pair_index, &[left_hash, right_hash])| {
192 let mut left_proof = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); N.ilog2() as usize];
193 left_proof[0].write(right_hash);
194
195 let left_proof = {
196 let (_, shared_proof) = left_proof.split_at_mut(1);
197
198 let mut tree_hashes = self.tree.as_slice();
199 let mut parent_position = pair_index;
200 let mut parent_level_size = N / 2;
201
202 for hash in shared_proof {
203 let parent_other_position = if parent_position % 2 == 0 {
204 parent_position + 1
205 } else {
206 parent_position - 1
207 };
208 let other_hash =
210 unsafe { tree_hashes.get_unchecked(parent_other_position) };
211 hash.write(*other_hash);
212 (_, tree_hashes) = tree_hashes.split_at(parent_level_size);
213
214 parent_position /= 2;
215 parent_level_size /= 2;
216 }
217
218 unsafe { left_proof.transpose().assume_init() }
220 };
221
222 let mut right_proof = left_proof;
223 right_proof[0] = left_hash;
224
225 [left_proof, right_proof]
226 })
227 .chain({
230 let mut returned = false;
231
232 iter::from_fn(move || {
233 if N == 1 && !returned {
234 returned = true;
235 Some([[0; OUT_LEN]; N.ilog2() as usize])
236 } else {
237 None
238 }
239 })
240 });
241
242 ProofsIterator { iter, len: N }
243 }
244
245 #[inline]
247 pub fn verify(
248 root: &[u8; OUT_LEN],
249 proof: &[[u8; OUT_LEN]; N.ilog2() as usize],
250 leaf_index: usize,
251 leaf: [u8; OUT_LEN],
252 ) -> bool
253 where
254 [(); N.ilog2() as usize]:,
255 {
256 if leaf_index >= N {
257 return false;
258 }
259
260 let mut computed_root = leaf;
261
262 let mut position = leaf_index;
263 for hash in proof {
264 computed_root = if position % 2 == 0 {
265 hash_pair(&computed_root, hash)
266 } else {
267 hash_pair(hash, &computed_root)
268 };
269
270 position /= 2;
271 }
272
273 root == &computed_root
274 }
275}
276
/// Iterator adapter that reports an exact, externally tracked number of remaining items.
///
/// `len` is decremented (saturating at zero) on every call to `next()`, whether or not the inner
/// iterator produced an item.
struct ProofsIterator<Iter> {
    /// Underlying iterator producing the items.
    iter: Iter,
    /// Number of items still expected from `iter`.
    len: usize,
}

impl<Iter> Iterator for ProofsIterator<Iter>
where
    Iter: Iterator,
{
    type Item = Iter::Item;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let next_item = self.iter.next();
        // One fewer item remains, never going below zero
        self.len = self.len.checked_sub(1).unwrap_or_default();
        next_item
    }

    #[inline(always)]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.len;
        (remaining, Some(remaining))
    }

    #[inline(always)]
    fn count(self) -> usize
    where
        Self: Sized,
    {
        // The length is tracked exactly, so there is no need to drain the inner iterator
        let Self { len, .. } = self;
        len
    }
}

impl<Iter> ExactSizeIterator for ProofsIterator<Iter>
where
    Iter: Iterator,
{
    #[inline(always)]
    fn len(&self) -> usize {
        let (lower, _) = self.size_hint();
        lower
    }
}
318
319unsafe impl<Iter> TrustedLen for ProofsIterator<Iter> where Iter: Iterator {}