// ab_merkle_tree/balanced.rs
use crate::hash_pair;
use ab_blake3::OUT_LEN;
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
use core::iter;
use core::iter::TrustedLen;
use core::mem::MaybeUninit;

9#[derive(Debug)]
27pub struct BalancedMerkleTree<'a, const N: usize>
28where
29 [(); N - 1]:,
30{
31 leaves: &'a [[u8; OUT_LEN]],
32 tree: [[u8; OUT_LEN]; N - 1],
34}
35
36impl<'a, const N: usize> BalancedMerkleTree<'a, N>
39where
40 [(); N - 1]:,
41{
42 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
47 pub fn new(leaves: &'a [[u8; OUT_LEN]; N]) -> Self {
48 let mut tree = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); _];
49
50 Self::init_internal(leaves, &mut tree);
51
52 Self {
53 leaves,
54 tree: unsafe { tree.transpose().assume_init() },
56 }
57 }
58
59 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
61 pub fn new_in<'b>(
62 instance: &'b mut MaybeUninit<Self>,
63 leaves: &'a [[u8; OUT_LEN]; N],
64 ) -> &'b mut Self {
65 let instance_ptr = instance.as_mut_ptr();
66 unsafe {
68 (&raw mut (*instance_ptr).leaves).write(leaves);
69 }
70 let tree = {
71 let tree_ptr = unsafe { &raw mut (*instance_ptr).tree };
73 unsafe {
75 tree_ptr
76 .cast::<[MaybeUninit<[u8; OUT_LEN]>; N - 1]>()
77 .as_mut_unchecked()
78 }
79 };
80
81 Self::init_internal(leaves, tree);
82
83 unsafe { instance.assume_init_mut() }
85 }
86
87 #[cfg(feature = "alloc")]
90 pub fn new_boxed(leaves: &'a [[u8; OUT_LEN]; N]) -> Box<Self> {
91 let mut instance = Box::<Self>::new_uninit();
92
93 Self::new_in(&mut instance, leaves);
94
95 unsafe { instance.assume_init() }
97 }
98
99 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
100 fn init_internal(leaves: &[[u8; OUT_LEN]; N], tree: &mut [MaybeUninit<[u8; OUT_LEN]>; N - 1]) {
101 let mut tree_hashes = tree.as_mut_slice();
102 let mut level_hashes = leaves.as_slice();
103
104 let mut pair = [0u8; OUT_LEN * 2];
105 while level_hashes.len() > 1 {
106 let num_pairs = level_hashes.len() / 2;
107 let parent_hashes;
108 (parent_hashes, tree_hashes) = unsafe { tree_hashes.split_at_mut_unchecked(num_pairs) };
111
112 for ([left_hash, right_hash], parent_hash) in
113 level_hashes.array_chunks().zip(parent_hashes.iter_mut())
114 {
115 pair[..OUT_LEN].copy_from_slice(left_hash);
116 pair[OUT_LEN..].copy_from_slice(right_hash);
117
118 parent_hash.write(hash_pair(left_hash, right_hash));
119 }
120
121 level_hashes = unsafe { parent_hashes.assume_init_ref() };
123 }
124 }
125
126 #[inline]
131 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
132 pub fn compute_root_only(leaves: &[[u8; OUT_LEN]; N]) -> [u8; OUT_LEN]
133 where
134 [(); N.ilog2() as usize + 1]:,
135 {
136 if leaves.len() == 1 {
137 return leaves[0];
138 }
139
140 let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
142
143 for (num_leaves, &hash) in leaves.iter().enumerate() {
144 let mut current = hash;
145
146 let lowest_active_levels = num_leaves.trailing_ones() as usize;
148 for item in stack.iter().take(lowest_active_levels) {
149 current = hash_pair(item, ¤t);
150 }
151
152 stack[lowest_active_levels] = current;
154 }
155
156 stack[N.ilog2() as usize]
157 }
158
159 #[inline]
163 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
164 pub fn root(&self) -> [u8; OUT_LEN] {
165 *self
166 .tree
167 .last()
168 .or(self.leaves.last())
169 .expect("There is always at least one leaf hash; qed")
170 }
171
172 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
174 pub fn all_proofs(
175 &self,
176 ) -> impl ExactSizeIterator<Item = [[u8; OUT_LEN]; N.ilog2() as usize]> + TrustedLen
177 where
178 [(); N.ilog2() as usize]:,
179 {
180 let iter = self
181 .leaves
182 .array_chunks()
183 .enumerate()
184 .flat_map(|(pair_index, &[left_hash, right_hash])| {
185 let mut left_proof = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); N.ilog2() as usize];
186 left_proof[0].write(right_hash);
187
188 let left_proof = {
189 let (_, shared_proof) = left_proof.split_at_mut(1);
190
191 let mut tree_hashes = self.tree.as_slice();
192 let mut parent_position = pair_index;
193 let mut parent_level_size = N / 2;
194
195 for hash in shared_proof {
196 let parent_other_position = parent_position ^ 1;
203
204 let other_hash =
206 unsafe { tree_hashes.get_unchecked(parent_other_position) };
207 hash.write(*other_hash);
208 (_, tree_hashes) = tree_hashes.split_at(parent_level_size);
209
210 parent_position /= 2;
211 parent_level_size /= 2;
212 }
213
214 unsafe { left_proof.transpose().assume_init() }
216 };
217
218 let mut right_proof = left_proof;
219 right_proof[0] = left_hash;
220
221 [left_proof, right_proof]
222 })
223 .chain({
226 let mut returned = false;
227
228 iter::from_fn(move || {
229 if N == 1 && !returned {
230 returned = true;
231 Some([[0; OUT_LEN]; N.ilog2() as usize])
232 } else {
233 None
234 }
235 })
236 });
237
238 ProofsIterator { iter, len: N }
239 }
240
241 #[inline]
243 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
244 pub fn verify(
245 root: &[u8; OUT_LEN],
246 proof: &[[u8; OUT_LEN]; N.ilog2() as usize],
247 leaf_index: usize,
248 leaf: [u8; OUT_LEN],
249 ) -> bool
250 where
251 [(); N.ilog2() as usize]:,
252 {
253 if leaf_index >= N {
254 return false;
255 }
256
257 let mut computed_root = leaf;
258
259 let mut position = leaf_index;
260 for hash in proof {
261 computed_root = if position % 2 == 0 {
262 hash_pair(&computed_root, hash)
263 } else {
264 hash_pair(hash, &computed_root)
265 };
266
267 position /= 2;
268 }
269
270 root == &computed_root
271 }
272}
273
/// Iterator adapter that reports a caller-supplied number of remaining items as its exact length.
///
/// The wrapped iterator is not consulted for its size, so whoever constructs this type must set
/// `len` to the number of items `iter` is going to yield.
struct ProofsIterator<Iter> {
    /// Iterator that produces the items
    iter: Iter,
    /// Number of items that are still expected to be yielded
    len: usize,
}

impl<Iter> Iterator for ProofsIterator<Iter>
where
    Iter: Iterator,
{
    type Item = Iter::Item;

    /// Forwards to the wrapped iterator and counts the remaining length down, stopping at zero.
    #[inline(always)]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.iter.next();
        // Saturating so that polling an exhausted iterator can't wrap the counter around
        self.len = self.len.saturating_sub(1);
        item
    }

    /// Reports the tracked remaining length as both lower and upper bound.
    #[inline(always)]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.len, Some(self.len))
    }

    /// Returns the tracked remaining length without driving the wrapped iterator.
    #[inline(always)]
    fn count(self) -> usize
    where
        Self: Sized,
    {
        self.len
    }
}

impl<Iter> ExactSizeIterator for ProofsIterator<Iter>
where
    Iter: Iterator,
{
    /// Returns the tracked remaining length.
    #[inline(always)]
    fn len(&self) -> usize {
        self.len
    }
}

317unsafe impl<Iter> TrustedLen for ProofsIterator<Iter> where Iter: Iterator {}