ab_sparse_merkle_tree/lib.rs
//! Sparse Merkle Tree and related data structures.
//!
//! A Sparse Merkle Tree is essentially a huge Balanced Merkle Tree, where most of the leaves are
//! empty. By "empty" here we mean `[0u8; 32]`. In order to optimize proofs and their verification,
//! the hashing function is customized: it returns `[0u8; 32]` when both the left and right
//! branches are `[0u8; 32]`; otherwise, a BLAKE3 hash is used, like in a Balanced Merkle Tree.
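//!
//! Illustrative sketch of the zero-collapse rule (not compiled as a doc-test): a tree whose leaves
//! are all empty has the root `[0u8; 32]`.
//!
//! ```ignore
//! use ab_sparse_merkle_tree::{Leaf, SparseMerkleTree};
//!
//! // The iterator ends immediately, so all 2^8 leaves are empty and every level stays zero
//! let root = SparseMerkleTree::<8>::compute_root_only(core::iter::empty::<Leaf<'_>>());
//! assert_eq!(root, Some([0u8; 32]));
//! ```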

#![expect(incomplete_features, reason = "generic_const_exprs")]
#![feature(generic_const_exprs)]
#![no_std]

use ab_blake3::{KEY_LEN, OUT_LEN};

/// Used as a key in keyed blake3 hash for inner nodes of Merkle Trees.
///
/// This value is a blake3 hash of the string `merkle-tree-inner-node`.
pub const INNER_NODE_DOMAIN_SEPARATOR: [u8; KEY_LEN] =
    ab_blake3::const_hash(b"merkle-tree-inner-node");

/// Helper function to hash two nodes together using [`ab_blake3::single_block_keyed_hash()`] and
/// [`INNER_NODE_DOMAIN_SEPARATOR`].
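///
/// Illustrative usage sketch (not compiled as a doc-test):
///
/// ```ignore
/// use ab_blake3::OUT_LEN;
/// use ab_sparse_merkle_tree::hash_pair;
///
/// let left = [1u8; OUT_LEN];
/// let right = [2u8; OUT_LEN];
/// // Hash of the parent node for the two child hashes
/// let parent = hash_pair(&left, &right);
/// ```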
#[inline(always)]
#[cfg_attr(feature = "no-panic", no_panic::no_panic)]
pub fn hash_pair(left: &[u8; OUT_LEN], right: &[u8; OUT_LEN]) -> [u8; OUT_LEN] {
    let mut pair = [0u8; OUT_LEN * 2];
    pair[..OUT_LEN].copy_from_slice(left);
    pair[OUT_LEN..].copy_from_slice(right);

    ab_blake3::single_block_keyed_hash(&INNER_NODE_DOMAIN_SEPARATOR, &pair)
        .expect("Exactly one block worth of data; qed")
}

/// Ensuring only supported `BITS` values can be specified for [`SparseMerkleTree`].
///
/// This is essentially a workaround for the current Rust type system constraints that do not allow
/// a nicer way to do the same thing at compile time.
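///
/// Illustrative usage sketch (not compiled as a doc-test); `Wrapper` is a hypothetical type that
/// mirrors how [`SparseMerkleTree`] applies this bound:
///
/// ```ignore
/// struct Wrapper<const BITS: u8>;
///
/// impl<const BITS: u8> Wrapper<BITS>
/// where
///     // Evaluating the function in a const context rejects unsupported `BITS` at compile time
///     [(); ensure_supported_bits(BITS)]:,
/// {
/// }
/// ```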
pub const fn ensure_supported_bits(bits: u8) -> usize {
    assert!(
        bits <= 128,
        "This Sparse Merkle Tree doesn't support more than 2^128 leaves"
    );

    assert!(
        bits != 0,
        "This Sparse Merkle Tree must have more than one leaf"
    );

    0
}

/// Sparse Merkle Tree Leaf
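///
/// Illustrative sketch (not compiled as a doc-test):
///
/// ```ignore
/// let hash = [1u8; 32];
/// // A leaf that holds the hash above
/// let occupied = Leaf::from(&hash);
/// // Three consecutive empty leaves
/// let empty = Leaf::Empty { skip_count: 3 };
/// ```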
#[derive(Debug)]
pub enum Leaf<'a> {
    // TODO: Batch of leaves for efficiency, especially with SIMD?
    /// Leaf is occupied by a value
    Occupied {
        /// Leaf value
        leaf: &'a [u8; OUT_LEN],
    },
    /// Leaf is empty
    Empty {
        /// Number of consecutive empty leaves
        skip_count: u128,
    },
}

impl<'a> From<&'a [u8; OUT_LEN]> for Leaf<'a> {
    #[inline(always)]
    fn from(leaf: &'a [u8; OUT_LEN]) -> Self {
        Self::Occupied { leaf }
    }
}

// TODO: A version that can hold intermediate nodes in memory, efficiently update leaves, etc.
/// Sparse Merkle Tree variant that has hash-sized leaves, with most leaves being empty
/// (have value `[0u8; 32]`).
///
/// In contrast to a proper Balanced Merkle Tree, constant `BITS` here specifies the max number of
/// leaves hypothetically possible in a tree (2^BITS, often intractable), rather than the number of
/// non-empty leaves actually present.
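///
/// Illustrative sketch (`AddressSpaceTree` is a hypothetical alias):
///
/// ```ignore
/// use ab_sparse_merkle_tree::SparseMerkleTree;
///
/// // A tree addressed by a 64-bit key, i.e. up to 2^64 leaves, no matter how many are occupied
/// type AddressSpaceTree = SparseMerkleTree<64>;
/// ```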
#[derive(Debug)]
pub struct SparseMerkleTree<const BITS: u8>;

// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
// https://github.com/BLAKE3-team/BLAKE3/issues/478
impl<const BITS: u8> SparseMerkleTree<BITS>
where
    [(); ensure_supported_bits(BITS)]:,
{
    // TODO: Method that generates not only root, but also proof, like Unbalanced Merkle Tree
    /// Compute Merkle Tree root.
    ///
    /// If the provided iterator ends early, it means the rest of the leaves are empty.
    ///
    /// There must be no [`Leaf::Occupied`] for empty/unoccupied leaves, or else it may result in
    /// an invalid root; [`Leaf::Empty`] must be used instead.
    ///
    /// Returns `None` if too many leaves were provided.
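    ///
    /// Illustrative sketch (not compiled as a doc-test):
    ///
    /// ```ignore
    /// use ab_sparse_merkle_tree::{Leaf, SparseMerkleTree};
    ///
    /// let a = [1u8; 32];
    /// let b = [2u8; 32];
    /// // Leaves 0 and 6 are occupied, leaves 1..=5 are explicitly skipped, and the remaining
    /// // leaves up to 2^8 are implicitly empty because the iterator ends early
    /// let root = SparseMerkleTree::<8>::compute_root_only([
    ///     Leaf::from(&a),
    ///     Leaf::Empty { skip_count: 5 },
    ///     Leaf::from(&b),
    /// ]);
    /// assert!(root.is_some());
    /// ```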
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn compute_root_only<'a, Iter>(leaves: Iter) -> Option<[u8; OUT_LEN]>
    where
        [(); BITS as usize + 1]:,
        Iter: IntoIterator<Item = Leaf<'a>> + 'a,
    {
        // Stack of intermediate nodes per tree level
        let mut stack = [[0u8; OUT_LEN]; BITS as usize + 1];
        let mut processed_some = false;
        let mut num_leaves = 0_u128;

        for leaf in leaves {
            if u32::from(BITS) < u128::BITS {
                // How many leaves were processed so far
                if num_leaves == 2u128.pow(u32::from(BITS)) {
                    return None;
                }
            } else {
                // For `BITS == u128::BITS` `num_leaves` will wrap around back to zero right at the
                // very end
                if processed_some && num_leaves == 0 {
                    return None;
                }
                processed_some = true;
            }

            match leaf {
                Leaf::Occupied { leaf } => {
                    let mut current = *leaf;

                    // Every bit set to `1` corresponds to an active Merkle Tree level
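                    // For example, with 7 leaves processed (`num_leaves == 0b111`), the three
                    // lowest levels hold pending nodes, so the new leaf is folded into them and
                    // stored one level higher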
                    let lowest_active_levels = num_leaves.trailing_ones() as usize;
                    for item in stack.iter().take(lowest_active_levels) {
                        current = hash_pair(item, &current);
                    }

                    // Place the current hash at the first inactive level
                    // SAFETY: Number of lowest active levels corresponds to the number of inserted
                    // elements, which in turn is checked above to fit into 2^BITS, while `BITS`
                    // generic in turn ensured sufficient stack size
                    *unsafe { stack.get_unchecked_mut(lowest_active_levels) } = current;
                    // Wrapping is needed for `BITS == u128::BITS`, where number of leaves narrowly
                    // doesn't fit into `u128` itself
                    num_leaves = num_leaves.wrapping_add(1);
                }
                Leaf::Empty { skip_count } => {
                    num_leaves =
                        Self::skip_leaves(&mut stack, &mut processed_some, num_leaves, skip_count)?;
                }
            }
        }

        if u32::from(BITS) < u128::BITS {
            Self::skip_leaves(
                &mut stack,
                &mut processed_some,
                num_leaves,
                2u128.pow(u32::from(BITS)) - num_leaves,
            )?;
        } else if processed_some && num_leaves != 0 {
            // For `BITS == u128::BITS` `num_leaves` will wrap around back to zero right at the
            // very end, so we reverse the mechanism here
            Self::skip_leaves(
                &mut stack,
                &mut processed_some,
                num_leaves,
                0u128.wrapping_sub(num_leaves),
            )?;
        }

        Some(stack[BITS as usize])
    }

    /// Returns updated number of leaves
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    fn skip_leaves(
        stack: &mut [[u8; OUT_LEN]; BITS as usize + 1],
        processed_some: &mut bool,
        mut num_leaves: u128,
        mut skip_count: u128,
    ) -> Option<u128>
    where
        [(); BITS as usize + 1]:,
    {
        const ZERO: [u8; OUT_LEN] = [0; OUT_LEN];

        if u32::from(BITS) < u128::BITS {
            // How many leaves were processed so far
            if num_leaves.checked_add(skip_count)? > 2u128.pow(u32::from(BITS)) {
                return None;
            }
        } else {
            // For `BITS == u128::BITS` `num_leaves` will wrap around back to zero right at the
            // very end
            let (overflow_amount, overflowed) = num_leaves.overflowing_add(skip_count);
            if *processed_some && overflowed && overflow_amount > 0 {
                return None;
            }
            *processed_some = true;
        }

        while skip_count > 0 {
            // Find the largest aligned chunk to skip for the current state of the tree
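            // For example, with `num_leaves == 4` and `skip_count == 5`, the first iteration
            // skips an aligned subtree of 4 empty leaves at once and the second iteration skips
            // the remaining single leaf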
            let max_levels_to_skip = skip_count.ilog2().min(num_leaves.trailing_zeros());
            let chunk_size = 1u128 << max_levels_to_skip;

            let mut level = max_levels_to_skip;
            let mut current = ZERO;
            for item in stack.iter().skip(max_levels_to_skip as usize) {
                // Check active level for merging up the stack.
                //
                // `BITS == u128::BITS` condition is only added for better dead code elimination
                // since that check is only relevant for the 2^128 leaves case and nothing else.
                if (u32::from(BITS) == u128::BITS && level == u128::BITS)
                    || num_leaves & (1 << level) == 0
                {
                    // Level wasn't active before, stop here
                    break;
                }

                // Hash together unless both are zero
                if !(item == &ZERO && current == ZERO) {
                    current = hash_pair(item, &current);
                }

                level += 1;
            }
            // SAFETY: Level is limited by the number of leaves, which in turn is checked above to
            // fit into 2^BITS, while `BITS` generic in turn ensured sufficient stack size
            *unsafe { stack.get_unchecked_mut(level as usize) } = current;

            // Wrapping is needed for `BITS == u128::BITS`, where number of leaves narrowly
            // doesn't fit into `u128` itself
            num_leaves = num_leaves.wrapping_add(chunk_size);
            skip_count -= chunk_size;
        }

        Some(num_leaves)
    }

    /// Verify a previously generated proof.
    ///
    /// The leaf can either be the value of an occupied leaf or `[0; 32]` for a leaf that is
    /// supposed to be empty.
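    ///
    /// Illustrative sketch (not compiled as a doc-test), using a two-leaf tree where the proof for
    /// leaf `0` is just its sibling:
    ///
    /// ```ignore
    /// use ab_sparse_merkle_tree::{Leaf, SparseMerkleTree};
    ///
    /// let a = [1u8; 32];
    /// let b = [2u8; 32];
    /// let root = SparseMerkleTree::<1>::compute_root_only([Leaf::from(&a), Leaf::from(&b)])
    ///     .expect("Two leaves fit into a 1-bit tree; qed");
    /// assert!(SparseMerkleTree::<1>::verify(&root, &[b], 0, a));
    /// ```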
    #[inline]
    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
    pub fn verify(
        root: &[u8; OUT_LEN],
        proof: &[[u8; OUT_LEN]; BITS as usize],
        leaf_index: u128,
        leaf: [u8; OUT_LEN],
    ) -> bool
    where
        [(); BITS as usize]:,
    {
        // For `BITS == u128::BITS` any index is valid by definition
        if u32::from(BITS) < u128::BITS && leaf_index >= 2u128.pow(u32::from(BITS)) {
            return false;
        }

        let mut computed_root = leaf;

        let mut position = leaf_index;
        for hash in proof {
            computed_root = if position.is_multiple_of(2) {
                hash_pair(&computed_root, hash)
            } else {
                hash_pair(hash, &computed_root)
            };

            position /= 2;
        }

        root == &computed_root
    }
}