ab_merkle_tree/
balanced.rs1use crate::hash_pair;
2use ab_blake3::OUT_LEN;
3#[cfg(feature = "alloc")]
4use alloc::boxed::Box;
5use core::iter::TrustedLen;
6use core::mem::MaybeUninit;
7use core::num::NonZero;
8
/// Validates that `n` is a supported number of leaves for [`BalancedMerkleTree`].
///
/// Always returns `0`, which lets it be used in a `[(); ensure_supported_n(N)]:` bound so the
/// check is evaluated at compile time for every instantiation of the tree's `impl` block.
///
/// # Panics
///
/// Panics if `n` is not a power of two (including `0`), or if `n` is `1`.
pub const fn ensure_supported_n(n: usize) -> usize {
    if !n.is_power_of_two() {
        panic!("Balanced Merkle Tree must have a number of leaves that is a power of 2");
    }

    if n <= 1 {
        panic!("This Balanced Merkle Tree must have more than one leaf");
    }

    0
}
26
27#[derive(Debug)]
45pub struct BalancedMerkleTree<'a, const N: usize>
46where
47 [(); N - 1]:,
48{
49 leaves: &'a [[u8; OUT_LEN]],
50 tree: [[u8; OUT_LEN]; N - 1],
52}
53
54impl<'a, const N: usize> BalancedMerkleTree<'a, N>
57where
58 [(); N - 1]:,
59 [(); ensure_supported_n(N)]:,
60{
61 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
66 pub fn new(leaves: &'a [[u8; OUT_LEN]; N]) -> Self {
67 let mut tree = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); _];
68
69 Self::init_internal(leaves, &mut tree);
70
71 Self {
72 leaves,
73 tree: unsafe { tree.transpose().assume_init() },
75 }
76 }
77
78 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
80 pub fn new_in<'b>(
81 instance: &'b mut MaybeUninit<Self>,
82 leaves: &'a [[u8; OUT_LEN]; N],
83 ) -> &'b mut Self {
84 let instance_ptr = instance.as_mut_ptr();
85 unsafe {
87 (&raw mut (*instance_ptr).leaves).write(leaves);
88 }
89 let tree = {
90 let tree_ptr = unsafe { &raw mut (*instance_ptr).tree };
92 unsafe {
94 tree_ptr
95 .cast::<[MaybeUninit<[u8; OUT_LEN]>; N - 1]>()
96 .as_mut_unchecked()
97 }
98 };
99
100 Self::init_internal(leaves, tree);
101
102 unsafe { instance.assume_init_mut() }
104 }
105
106 #[cfg(feature = "alloc")]
109 pub fn new_boxed(leaves: &'a [[u8; OUT_LEN]; N]) -> Box<Self> {
110 let mut instance = Box::<Self>::new_uninit();
111
112 Self::new_in(&mut instance, leaves);
113
114 unsafe { instance.assume_init() }
116 }
117
118 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
119 fn init_internal(leaves: &[[u8; OUT_LEN]; N], tree: &mut [MaybeUninit<[u8; OUT_LEN]>; N - 1]) {
120 let mut tree_hashes = tree.as_mut_slice();
121 let mut level_hashes = leaves.as_slice();
122
123 let mut pair = [0u8; OUT_LEN * 2];
124 while level_hashes.len() > 1 {
125 let num_pairs = level_hashes.len() / 2;
126 let parent_hashes;
127 (parent_hashes, tree_hashes) = unsafe { tree_hashes.split_at_mut_unchecked(num_pairs) };
130
131 for ([left_hash, right_hash], parent_hash) in
132 level_hashes.array_chunks().zip(parent_hashes.iter_mut())
133 {
134 pair[..OUT_LEN].copy_from_slice(left_hash);
135 pair[OUT_LEN..].copy_from_slice(right_hash);
136
137 parent_hash.write(hash_pair(left_hash, right_hash));
138 }
139
140 level_hashes = unsafe { parent_hashes.assume_init_ref() };
142 }
143 }
144
145 #[inline]
151 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
152 pub fn compute_root_only(leaves: &[[u8; OUT_LEN]; N]) -> [u8; OUT_LEN]
153 where
154 [(); N.ilog2() as usize + 1]:,
155 {
156 let mut stack = [[0u8; OUT_LEN]; N.ilog2() as usize + 1];
158
159 for (num_leaves, &hash) in leaves.iter().enumerate() {
160 let mut current = hash;
161
162 let lowest_active_levels = num_leaves.trailing_ones() as usize;
164 for item in stack.iter().take(lowest_active_levels) {
165 current = hash_pair(item, ¤t);
166 }
167
168 stack[lowest_active_levels] = current;
170 }
171
172 stack[N.ilog2() as usize]
173 }
174
175 #[inline]
177 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
178 pub fn root(&self) -> [u8; OUT_LEN] {
179 *self
180 .tree
181 .last()
182 .or(self.leaves.last())
183 .expect("There is always at least one leaf hash; qed")
184 }
185
186 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
188 pub fn all_proofs(
189 &self,
190 ) -> impl ExactSizeIterator<Item = [[u8; OUT_LEN]; N.ilog2() as usize]> + TrustedLen
191 where
192 [(); N.ilog2() as usize]:,
193 {
194 let iter = self.leaves.array_chunks().enumerate().flat_map(
195 |(pair_index, &[left_hash, right_hash])| {
196 let mut left_proof = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); N.ilog2() as usize];
197 left_proof[0].write(right_hash);
198
199 let left_proof = {
200 let (_, shared_proof) = left_proof.split_at_mut(1);
201
202 let mut tree_hashes = self.tree.as_slice();
203 let mut parent_position = pair_index;
204 let mut parent_level_size = N / 2;
205
206 for hash in shared_proof {
207 let parent_other_position = parent_position ^ 1;
214
215 let other_hash =
217 unsafe { tree_hashes.get_unchecked(parent_other_position) };
218 hash.write(*other_hash);
219 (_, tree_hashes) = tree_hashes.split_at(parent_level_size);
220
221 parent_position /= 2;
222 parent_level_size /= 2;
223 }
224
225 unsafe { left_proof.transpose().assume_init() }
227 };
228
229 let mut right_proof = left_proof;
230 right_proof[0] = left_hash;
231
232 [left_proof, right_proof]
233 },
234 );
235
236 ProofsIterator { iter, len: N }
237 }
238
239 #[inline]
241 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
242 pub fn verify(
243 root: &[u8; OUT_LEN],
244 proof: &[[u8; OUT_LEN]; N.ilog2() as usize],
245 leaf_index: usize,
246 leaf: [u8; OUT_LEN],
247 ) -> bool
248 where
249 [(); N.ilog2() as usize]:,
250 {
251 if leaf_index >= N {
252 return false;
253 }
254
255 let mut computed_root = leaf;
256
257 let mut position = leaf_index;
258 for hash in proof {
259 computed_root = if position.is_multiple_of(2) {
260 hash_pair(&computed_root, hash)
261 } else {
262 hash_pair(hash, &computed_root)
263 };
264
265 position /= 2;
266 }
267
268 root == &computed_root
269 }
270}
271
/// Iterator adapter that carries an exact count of the items remaining in the wrapped iterator.
///
/// `len` must equal the number of items `iter` will still yield; the `ExactSizeIterator` and
/// `TrustedLen` implementations report it as-is.
struct ProofsIterator<Iter> {
    /// Wrapped iterator that produces the items
    iter: Iter,
    /// Number of items `iter` has yet to produce
    len: usize,
}
276
277impl<Iter> Iterator for ProofsIterator<Iter>
278where
279 Iter: Iterator,
280{
281 type Item = Iter::Item;
282
283 #[inline(always)]
284 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
285 fn next(&mut self) -> Option<Self::Item> {
286 let item = self.iter.next();
287 self.len = self.len.saturating_sub(1);
288 item
289 }
290
291 #[inline(always)]
292 fn size_hint(&self) -> (usize, Option<usize>) {
293 (self.len, Some(self.len))
294 }
295
296 #[inline(always)]
297 fn count(self) -> usize
298 where
299 Self: Sized,
300 {
301 self.len
302 }
303
304 #[inline(always)]
305 fn last(self) -> Option<Self::Item>
306 where
307 Self: Sized,
308 {
309 self.iter.last()
310 }
311
312 #[inline(always)]
313 fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
314 self.len = self.len.saturating_sub(n);
315 self.iter.advance_by(n)
316 }
317
318 #[inline(always)]
319 fn nth(&mut self, n: usize) -> Option<Self::Item> {
320 self.len = self.len.saturating_sub(n.saturating_add(1));
321 self.iter.nth(n)
322 }
323}
324
325impl<Iter> ExactSizeIterator for ProofsIterator<Iter>
326where
327 Iter: Iterator,
328{
329 #[inline(always)]
330 fn len(&self) -> usize {
331 self.len
332 }
333}
334
335unsafe impl<Iter> TrustedLen for ProofsIterator<Iter> where Iter: Iterator {}