ab_merkle_tree/
balanced_hashed.rs1#[cfg(feature = "alloc")]
2extern crate alloc;
3
4#[cfg(feature = "alloc")]
5use alloc::boxed::Box;
6use blake3::OUT_LEN;
7use core::iter::TrustedLen;
8use core::mem;
9use core::mem::MaybeUninit;
10
/// Number of internal (non-leaf) hashes in a balanced binary tree with `2^num_leaves_log_2`
/// leaves, i.e. `2^num_leaves_log_2 - 1`.
///
/// Overflows (compile error in const contexts, panic in debug builds) when
/// `num_leaves_log_2 >= usize::BITS`.
#[inline(always)]
pub const fn num_hashes(num_leaves_log_2: u32) -> usize {
    let leaf_count = usize::pow(2, num_leaves_log_2);
    leaf_count - 1
}
16
/// Number of leaves in a balanced binary tree of depth `num_leaves_log_2`, i.e.
/// `2^num_leaves_log_2`.
///
/// Overflows (compile error in const contexts, panic in debug builds) when
/// `num_leaves_log_2 >= usize::BITS`.
#[inline(always)]
pub const fn num_leaves(num_leaves_log_2: u32) -> usize {
    usize::pow(2, num_leaves_log_2)
}
22
23#[derive(Debug)]
35pub struct BalancedHashedMerkleTree<'a, const NUM_LEAVES_LOG_2: u32>
36where
37 [(); num_hashes(NUM_LEAVES_LOG_2)]:,
38{
39 leaf_hashes: &'a [[u8; OUT_LEN]],
40 tree: [[u8; OUT_LEN]; num_hashes(NUM_LEAVES_LOG_2)],
42}
43
44impl<'a, const NUM_LEAVES_LOG_2: u32> BalancedHashedMerkleTree<'a, NUM_LEAVES_LOG_2>
50where
51 [(); num_hashes(NUM_LEAVES_LOG_2)]:,
52{
53 pub fn new(leaf_hashes: &'a [[u8; OUT_LEN]; num_leaves(NUM_LEAVES_LOG_2)]) -> Self {
58 let mut tree = [MaybeUninit::<[u8; OUT_LEN]>::uninit(); num_hashes(NUM_LEAVES_LOG_2)];
59
60 Self::init_internal(leaf_hashes, &mut tree);
61
62 Self {
63 leaf_hashes,
64 tree: unsafe { tree.transpose().assume_init() },
66 }
67 }
68
69 pub fn new_in<'b>(
71 instance: &'b mut MaybeUninit<Self>,
72 leaf_hashes: &'a [[u8; OUT_LEN]; num_leaves(NUM_LEAVES_LOG_2)],
73 ) -> &'b mut Self {
74 let instance_ptr = instance.as_mut_ptr();
75 unsafe {
77 (&raw mut (*instance_ptr).leaf_hashes).write(leaf_hashes);
78 }
79 let tree = {
80 let tree_ptr = unsafe { &raw mut (*instance_ptr).tree };
82 unsafe {
84 tree_ptr
85 .cast::<[MaybeUninit<[u8; OUT_LEN]>; num_hashes(NUM_LEAVES_LOG_2)]>()
86 .as_mut_unchecked()
87 }
88 };
89
90 Self::init_internal(leaf_hashes, tree);
91
92 unsafe { instance.assume_init_mut() }
94 }
95
96 #[cfg(feature = "alloc")]
99 pub fn new_boxed(leaf_hashes: &'a [[u8; OUT_LEN]; num_leaves(NUM_LEAVES_LOG_2)]) -> Box<Self> {
100 let mut instance = Box::<Self>::new_uninit();
101
102 Self::new_in(&mut instance, leaf_hashes);
103
104 unsafe { instance.assume_init() }
106 }
107
108 fn init_internal(
109 leaf_hashes: &[[u8; OUT_LEN]; num_leaves(NUM_LEAVES_LOG_2)],
110 tree: &mut [MaybeUninit<[u8; OUT_LEN]>; num_hashes(NUM_LEAVES_LOG_2)],
111 ) {
112 let mut tree_hashes = tree.as_mut_slice();
113 let mut level_hashes = leaf_hashes.as_slice();
114
115 let mut pair = [0u8; OUT_LEN * 2];
116 while level_hashes.len() > 1 {
117 let num_pairs = level_hashes.len() / 2;
118 let parent_hashes;
119 (parent_hashes, tree_hashes) = unsafe { tree_hashes.split_at_mut_unchecked(num_pairs) };
122
123 for pair_index in 0..num_pairs {
124 let left_hash = unsafe { level_hashes.get_unchecked(pair_index * 2) };
126 let right_hash = unsafe { level_hashes.get_unchecked(pair_index * 2 + 1) };
128 let parent_hash = unsafe { parent_hashes.get_unchecked_mut(pair_index) };
130
131 pair[..OUT_LEN].copy_from_slice(left_hash);
132 pair[OUT_LEN..].copy_from_slice(right_hash);
133
134 parent_hash.write(*blake3::hash(&pair).as_bytes());
135 }
136
137 level_hashes = unsafe { parent_hashes.assume_init_ref() };
139 }
140 }
141
142 #[inline]
146 pub fn root(&self) -> [u8; OUT_LEN] {
147 *self
148 .tree
149 .last()
150 .or(self.leaf_hashes.last())
151 .expect("There is always at least one leaf hash; qed")
152 }
153
154 pub fn all_proofs(
156 &self,
157 ) -> impl ExactSizeIterator<Item = [u8; OUT_LEN * NUM_LEAVES_LOG_2 as usize]> + TrustedLen
158 where
159 [(); OUT_LEN * NUM_LEAVES_LOG_2 as usize]:,
160 {
161 let iter = self.leaf_hashes.array_chunks().enumerate().flat_map(
162 |(pair_index, &[left_hash, right_hash])| {
163 let mut left_proof =
164 [MaybeUninit::<[u8; OUT_LEN]>::uninit(); NUM_LEAVES_LOG_2 as usize];
165 left_proof[0].write(right_hash);
166
167 let left_proof = {
168 let (_, shared_proof) = left_proof.split_at_mut(1);
169
170 let mut tree_hashes = self.tree.as_slice();
171 let mut parent_position = pair_index;
172 let mut parent_level_size = num_leaves(NUM_LEAVES_LOG_2) / 2;
173
174 for hash in shared_proof {
175 let parent_other_position = if parent_position % 2 == 0 {
176 parent_position + 1
177 } else {
178 parent_position - 1
179 };
180 let other_hash =
182 unsafe { tree_hashes.get_unchecked(parent_other_position) };
183 hash.write(*other_hash);
184 (_, tree_hashes) = tree_hashes.split_at(parent_level_size);
185
186 parent_position /= 2;
187 parent_level_size /= 2;
188 }
189
190 unsafe { left_proof.transpose().assume_init() }
192 };
193
194 let mut right_proof = left_proof;
195 right_proof[0] = left_hash;
196
197 let left_proof = unsafe {
201 mem::transmute_copy::<
202 [[u8; OUT_LEN]; NUM_LEAVES_LOG_2 as usize],
203 [u8; OUT_LEN * NUM_LEAVES_LOG_2 as usize],
204 >(&left_proof)
205 };
206 let right_proof = unsafe {
207 mem::transmute_copy::<
208 [[u8; OUT_LEN]; NUM_LEAVES_LOG_2 as usize],
209 [u8; OUT_LEN * NUM_LEAVES_LOG_2 as usize],
210 >(&right_proof)
211 };
212 [left_proof, right_proof]
213 },
214 );
215
216 ProofsIterator {
217 iter,
218 len: num_leaves(NUM_LEAVES_LOG_2),
219 }
220 }
221
222 #[inline]
224 pub fn verify(
225 root: &[u8; OUT_LEN],
226 proof: &[u8; OUT_LEN * NUM_LEAVES_LOG_2 as usize],
227 leaf_index: usize,
228 leaf_hash: [u8; OUT_LEN],
229 ) -> bool
230 where
231 [(); OUT_LEN * NUM_LEAVES_LOG_2 as usize]:,
232 {
233 if leaf_index >= num_leaves(NUM_LEAVES_LOG_2) {
234 return false;
235 }
236
237 let mut computed_root = leaf_hash;
238
239 let mut position = leaf_index;
240 let mut pair = [0u8; OUT_LEN * 2];
241 for hash in proof.array_chunks::<OUT_LEN>() {
242 if position % 2 == 0 {
243 pair[..OUT_LEN].copy_from_slice(&computed_root);
244 pair[OUT_LEN..].copy_from_slice(hash);
245 } else {
246 pair[..OUT_LEN].copy_from_slice(hash);
247 pair[OUT_LEN..].copy_from_slice(&computed_root);
248 }
249
250 position /= 2;
251 computed_root = *blake3::hash(&pair).as_bytes();
252 }
253
254 root == &computed_root
255 }
256}
257
/// Iterator adapter that pairs an inner iterator with an explicitly tracked item count.
///
/// It lets `all_proofs` expose `ExactSizeIterator` and `TrustedLen` for an inner iterator
/// built from `flat_map`, which does not provide either trait itself.
struct ProofsIterator<Iter> {
    /// Underlying iterator producing the items.
    iter: Iter,
    /// Item count reported through `size_hint`, `count` and `len`.
    len: usize,
}
262
263impl<Iter> Iterator for ProofsIterator<Iter>
264where
265 Iter: Iterator,
266{
267 type Item = Iter::Item;
268
269 #[inline(always)]
270 fn next(&mut self) -> Option<Self::Item> {
271 self.iter.next()
272 }
273
274 #[inline(always)]
275 fn size_hint(&self) -> (usize, Option<usize>) {
276 (self.len, Some(self.len))
277 }
278
279 #[inline(always)]
280 fn count(self) -> usize
281 where
282 Self: Sized,
283 {
284 self.len
285 }
286}
287
288impl<Iter> ExactSizeIterator for ProofsIterator<Iter>
289where
290 Iter: Iterator,
291{
292 #[inline(always)]
293 fn len(&self) -> usize {
294 self.len
295 }
296}
297
298unsafe impl<Iter> TrustedLen for ProofsIterator<Iter> where Iter: Iterator {}