ab_merkle_tree/sparse.rs
//! Sparse Merkle Tree and related data structures.
//!
//! Sparse Merkle Tree is essentially a huge Balanced Merkle Tree, where most of the leaves are
//! empty. By "empty" here we mean `[0u8; 32]`. To optimize proofs and their verification, the
//! hashing function is customized and returns `[0u8; 32]` when both left and right branch are
//! `[0u8; 32]`, otherwise BLAKE3 hash is used like in a Balanced Merkle Tree.
7
8use crate::{OUT_LEN, hash_pair};
9use core::num::NonZeroU128;
10
/// Ensuring only supported `BITS` can be specified for [`SparseMerkleTree`].
///
/// This is essentially a workaround for the current Rust type system constraints that do not allow
/// a nicer way to do the same thing at compile time.
///
/// The function always returns `0` for a supported value, which allows using it as an array length
/// in a `where` bound (`[(); ensure_supported_bits(BITS)]:`), so an unsupported value is rejected
/// when that bound is evaluated.
///
/// # Panics
///
/// Panics if `bits` is `0` or if `bits` is larger than `128`.
pub const fn ensure_supported_bits(bits: u8) -> usize {
    // Leaf indices and counters are tracked as `u128`, hence the upper bound
    assert!(
        bits <= 128,
        "This Sparse Merkle Tree doesn't support more than 2^128 leaves"
    );

    // `bits == 0` would describe a tree with a single leaf
    assert!(
        bits != 0,
        "This Sparse Merkle Tree must have more than one leaf"
    );

    0
}
28
29/// Sparse Merkle Tree Leaf
30#[derive(Debug)]
31pub enum Leaf<'a> {
32 // TODO: Batch of leaves for efficiently, especially with SIMD?
33 /// Leaf contains a value
34 Occupied {
35 /// Leaf value
36 leaf: &'a [u8; OUT_LEN],
37 },
38 /// Leaf contains a value (owned)
39 OccupiedOwned {
40 /// Leaf value
41 leaf: [u8; OUT_LEN],
42 },
43 /// Leaf is empty
44 Empty {
45 /// Number of consecutive empty leaves
46 skip_count: NonZeroU128,
47 },
48}
49
50impl<'a> From<&'a [u8; OUT_LEN]> for Leaf<'a> {
51 #[inline(always)]
52 fn from(leaf: &'a [u8; OUT_LEN]) -> Self {
53 Self::Occupied { leaf }
54 }
55}
56
// TODO: A version that can hold intermediate nodes in memory, efficiently update leaves, etc.
/// Sparse Merkle Tree variant that has hash-sized leaves, with most leaves being empty
/// (have value `[0u8; 32]`).
///
/// In contrast to a proper Balanced Merkle Tree, constant `BITS` here specifies the max number of
/// leaves hypothetically possible in a tree (2^BITS, often intractable), rather than the number of
/// non-empty leaves actually present.
#[derive(Debug)]
pub struct SparseMerkleTree<const BITS: u8>;
66
67// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
68// https://github.com/BLAKE3-team/BLAKE3/issues/478
69impl<const BITS: u8> SparseMerkleTree<BITS>
70where
71 [(); ensure_supported_bits(BITS)]:,
72{
73 // TODO: Method that generates not only root, but also proof, like Unbalanced Merkle Tree
74 /// Compute Merkle Tree root.
75 ///
76 /// If provided iterator ends early, it means the rest of the leaves are empty.
77 ///
78 /// There must be no [`Leaf::Occupied`] for empty/unoccupied leaves or else they may result in
79 /// invalid root, [`Leaf::Empty`] must be used instead.
80 ///
81 /// Returns `None` if too many leaves were provided.
82 #[inline]
83 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
84 pub fn compute_root_only<'a, Iter>(leaves: Iter) -> Option<[u8; OUT_LEN]>
85 where
86 [(); BITS as usize + 1]:,
87 Iter: IntoIterator<Item = Leaf<'a>> + 'a,
88 {
89 // Stack of intermediate nodes per tree level
90 let mut stack = [[0u8; OUT_LEN]; BITS as usize + 1];
91 let mut processed_some = false;
92 let mut num_leaves = 0_u128;
93
94 for leaf in leaves {
95 if u32::from(BITS) < u128::BITS {
96 // How many leaves were processed so far
97 if num_leaves == 2u128.pow(u32::from(BITS)) {
98 return None;
99 }
100 } else {
101 // For `BITS == u128::BITS` `num_leaves` will wrap around back to zero right at the
102 // very end
103 if processed_some && num_leaves == 0 {
104 return None;
105 }
106 processed_some = true;
107 }
108
109 let leaf = match leaf {
110 Leaf::Occupied { leaf } => *leaf,
111 Leaf::OccupiedOwned { leaf } => leaf,
112 Leaf::Empty { skip_count } => {
113 num_leaves = Self::skip_leaves(
114 &mut stack,
115 &mut processed_some,
116 num_leaves,
117 skip_count.get(),
118 )?;
119 continue;
120 }
121 };
122
123 let mut current = leaf;
124
125 // Every bit set to `1` corresponds to an active Merkle Tree level
126 let lowest_active_levels = num_leaves.trailing_ones() as usize;
127 for item in stack.iter().take(lowest_active_levels) {
128 current = hash_pair(item, ¤t);
129 }
130
131 // Place the current hash at the first inactive level
132 // SAFETY: Number of lowest active levels corresponds to the number of inserted
133 // elements, which in turn is checked above to fit into 2^BITS, while `BITS`
134 // generic in turn ensured sufficient stack size
135 *unsafe { stack.get_unchecked_mut(lowest_active_levels) } = current;
136 // Wrapping is needed for `BITS == u128::BITS`, where number of leaves narrowly
137 // doesn't fit into `u128` itself
138 num_leaves = num_leaves.wrapping_add(1);
139 }
140
141 if u32::from(BITS) < u128::BITS {
142 Self::skip_leaves(
143 &mut stack,
144 &mut processed_some,
145 num_leaves,
146 2u128.pow(u32::from(BITS)) - num_leaves,
147 )?;
148 } else if processed_some && num_leaves != 0 {
149 // For `BITS == u128::BITS` `num_leaves` will wrap around back to zero right at the
150 // very end, so we reverse the mechanism here
151 Self::skip_leaves(
152 &mut stack,
153 &mut processed_some,
154 num_leaves,
155 0u128.wrapping_sub(num_leaves),
156 )?;
157 }
158
159 Some(stack[BITS as usize])
160 }
161
162 /// Returns updated number of leaves
163 #[inline]
164 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
165 fn skip_leaves(
166 stack: &mut [[u8; OUT_LEN]; BITS as usize + 1],
167 processed_some: &mut bool,
168 mut num_leaves: u128,
169 mut skip_count: u128,
170 ) -> Option<u128>
171 where
172 [(); BITS as usize + 1]:,
173 {
174 const ZERO: [u8; OUT_LEN] = [0; OUT_LEN];
175
176 if u32::from(BITS) < u128::BITS {
177 // How many leaves were processed so far
178 if num_leaves.checked_add(skip_count)? > 2u128.pow(u32::from(BITS)) {
179 return None;
180 }
181 } else {
182 // For `BITS == u128::BITS` `num_leaves` will wrap around back to zero right at the
183 // very end
184 let (overflow_amount, overflowed) = num_leaves.overflowing_add(skip_count);
185 if *processed_some && overflowed && overflow_amount > 0 {
186 return None;
187 }
188 *processed_some = true;
189 }
190
191 while skip_count > 0 {
192 // Find the largest aligned chunk to skip for the current state of the tree
193 let max_levels_to_skip = skip_count.ilog2().min(num_leaves.trailing_zeros());
194 let chunk_size = 1u128 << max_levels_to_skip;
195
196 let mut level = max_levels_to_skip;
197 let mut current = ZERO;
198 for item in stack.iter().skip(max_levels_to_skip as usize) {
199 // Check the active level for merging up the stack.
200 //
201 // `BITS == u128::BITS` condition is only added for better dead code elimination
202 // since that check is only relevant for 2^128 leaves case and nothing else.
203 if (u32::from(BITS) == u128::BITS && level == u128::BITS)
204 || num_leaves & (1 << level) == 0
205 {
206 // Level wasn't active before, stop here
207 break;
208 }
209
210 // Hash together unless both are zero
211 if !(item == &ZERO && current == ZERO) {
212 current = hash_pair(item, ¤t);
213 }
214
215 level += 1;
216 }
217 // SAFETY: Level is limited by the number of leaves, which in turn is checked above to
218 // fit into 2^BITS, while `BITS` generic in turn ensured sufficient stack size
219 *unsafe { stack.get_unchecked_mut(level as usize) } = current;
220
221 // Wrapping is needed for `BITS == u128::BITS`, where number of leaves narrowly
222 // doesn't fit into `u128` itself
223 num_leaves = num_leaves.wrapping_add(chunk_size);
224 skip_count -= chunk_size;
225 }
226
227 Some(num_leaves)
228 }
229
230 /// Verify previously generated proof.
231 ///
232 /// Leaf can either be leaf value for a leaf that is occupied or `[0; 32]` for a leaf that is
233 /// supposed to be empty.
234 #[inline]
235 #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
236 pub fn verify(
237 root: &[u8; OUT_LEN],
238 proof: &[[u8; OUT_LEN]; BITS as usize],
239 leaf_index: u128,
240 leaf: [u8; OUT_LEN],
241 ) -> bool
242 where
243 [(); BITS as usize]:,
244 {
245 // For `BITS == u128::BITS` any index is valid by definition
246 if u32::from(BITS) < u128::BITS && leaf_index >= 2u128.pow(u32::from(BITS)) {
247 return false;
248 }
249
250 let mut computed_root = leaf;
251
252 let mut position = leaf_index;
253 for hash in proof {
254 computed_root = if position.is_multiple_of(2) {
255 hash_pair(&computed_root, hash)
256 } else {
257 hash_pair(hash, &computed_root)
258 };
259
260 position /= 2;
261 }
262
263 root == &computed_root
264 }
265}