ab_sparse_merkle_tree/
lib.rs

1//! Sparse Merkle Tree and related data structures.
2//!
3//! Sparse Merkle Tree is essentially a huge Balanced Merkle Tree, where most of the leaves are
4//! empty. By "empty" here we mean `[0u8; 32]`. In order to optimize proofs and their verification,
5//! hashing function is customized and returns `[0u8; 32]` when both left and right branch are
6//! `[0u8; 32]`, otherwise BLAKE3 hash is used like in a Balanced Merkle Tree.
7
8#![expect(incomplete_features, reason = "generic_const_exprs")]
9#![feature(generic_const_exprs)]
10#![no_std]
11
12use ab_blake3::{KEY_LEN, OUT_LEN};
13
14/// Used as a key in keyed blake3 hash for inner nodes of Merkle Trees.
15///
16/// This value is a blake3 hash of as string `merkle-tree-inner-node`.
17pub const INNER_NODE_DOMAIN_SEPARATOR: [u8; KEY_LEN] =
18    ab_blake3::const_hash(b"merkle-tree-inner-node");
19
20/// Helper function to hash two nodes together using [`ab_blake3::single_block_keyed_hash()`] and
21/// [`INNER_NODE_DOMAIN_SEPARATOR`]
22#[inline(always)]
23#[cfg_attr(feature = "no-panic", no_panic::no_panic)]
24pub fn hash_pair(left: &[u8; OUT_LEN], right: &[u8; OUT_LEN]) -> [u8; OUT_LEN] {
25    let mut pair = [0u8; OUT_LEN * 2];
26    pair[..OUT_LEN].copy_from_slice(left);
27    pair[OUT_LEN..].copy_from_slice(right);
28
29    ab_blake3::single_block_keyed_hash(&INNER_NODE_DOMAIN_SEPARATOR, &pair)
30        .expect("Exactly one block worth of data; qed")
31}
32
/// Compile-time guard that restricts `NUM_BITS` of [`SparseMerkleTree`] to supported values.
///
/// Accepts `1..=128` and always evaluates to `0` for those; any other value panics, which turns
/// into a compilation error when the function is evaluated in a const context (like an array
/// length in a `where` clause).
///
/// This is essentially a workaround for the current Rust type system constraints that do not allow
/// a nicer way to do the same thing at compile time.
pub const fn ensure_supported_bits(bits: u8) -> usize {
    match bits {
        0 => panic!("This Sparse Merkle Tree must have more than one leaf"),
        1..=128 => 0,
        _ => panic!("This Sparse Merkle Tree doesn't support more than 2^128 leaves"),
    }
}
50
51/// Sparse Merkle Tree Leaf
52#[derive(Debug)]
53pub enum Leaf<'a> {
54    // TODO: Batch of leaves for efficiently, especially with SIMD?
55    /// Leaf is occupied by a value
56    Occupied {
57        /// Leaf value
58        leaf: &'a [u8; OUT_LEN],
59    },
60    /// Leaf is empty
61    Empty {
62        /// Number of consecutive empty leaves
63        skip_count: u128,
64    },
65}
66
67impl<'a> From<&'a [u8; OUT_LEN]> for Leaf<'a> {
68    #[inline(always)]
69    fn from(leaf: &'a [u8; OUT_LEN]) -> Self {
70        Self::Occupied { leaf }
71    }
72}
73
// TODO: A version that can hold intermediate nodes in memory, efficiently update leaves, etc.
/// Sparse Merkle Tree with hash-sized leaves, where most of the leaves are expected to be empty
/// (have value `[0u8; 32]`).
///
/// Unlike in a proper Balanced Merkle Tree, constant `BITS` doesn't describe the number of
/// non-empty leaves actually present. Instead, it defines the maximum number of leaves the tree
/// could hypothetically hold (2^BITS, often intractable).
#[derive(Debug)]
pub struct SparseMerkleTree<const BITS: u8>;
83
84// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
85//  https://github.com/BLAKE3-team/BLAKE3/issues/478
86impl<const BITS: u8> SparseMerkleTree<BITS>
87where
88    [(); ensure_supported_bits(BITS)]:,
89{
90    // TODO: Method that generates not only root, but also proof, like Unbalanced Merkle Tree
91    /// Compute Merkle Tree root.
92    ///
93    /// If provided iterator ends early, it means the rest of the leaves are empty.
94    ///
95    /// There must be no [`Leaf::Occupied`] for empty/unoccupied leaves or else they may result in
96    /// invalid root, [`Leaf::Empty`] must be used instead.
97    ///
98    /// Returns `None` if too many leaves were provided.
99    #[inline]
100    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
101    pub fn compute_root_only<'a, Iter>(leaves: Iter) -> Option<[u8; OUT_LEN]>
102    where
103        [(); BITS as usize + 1]:,
104        Iter: IntoIterator<Item = Leaf<'a>> + 'a,
105    {
106        // Stack of intermediate nodes per tree level
107        let mut stack = [[0u8; OUT_LEN]; BITS as usize + 1];
108        let mut processed_some = false;
109        let mut num_leaves = 0_u128;
110
111        for leaf in leaves {
112            if u32::from(BITS) < u128::BITS {
113                // How many leaves were processed so far
114                if num_leaves == 2u128.pow(u32::from(BITS)) {
115                    return None;
116                }
117            } else {
118                // For `BITS == u128::BITS` `num_leaves` will wrap around back to zero right at the
119                // very end
120                if processed_some && num_leaves == 0 {
121                    return None;
122                }
123                processed_some = true;
124            }
125
126            match leaf {
127                Leaf::Occupied { leaf } => {
128                    let mut current = *leaf;
129
130                    // Every bit set to `1` corresponds to an active Merkle Tree level
131                    let lowest_active_levels = num_leaves.trailing_ones() as usize;
132                    for item in stack.iter().take(lowest_active_levels) {
133                        current = hash_pair(item, &current);
134                    }
135
136                    // Place the current hash at the first inactive level
137                    // SAFETY: Number of lowest active levels corresponds to the number of inserted
138                    // elements, which in turn is checked above to fit into 2^BITS, while `BITS`
139                    // generic in turn ensured sufficient stack size
140                    *unsafe { stack.get_unchecked_mut(lowest_active_levels) } = current;
141                    // Wrapping is needed for `BITS == u128::BITS`, where number of leaves narrowly
142                    // doesn't fit into `u128` itself
143                    num_leaves = num_leaves.wrapping_add(1);
144                }
145                Leaf::Empty { skip_count } => {
146                    num_leaves =
147                        Self::skip_leaves(&mut stack, &mut processed_some, num_leaves, skip_count)?;
148                }
149            }
150        }
151
152        if u32::from(BITS) < u128::BITS {
153            Self::skip_leaves(
154                &mut stack,
155                &mut processed_some,
156                num_leaves,
157                2u128.pow(u32::from(BITS)) - num_leaves,
158            )?;
159        } else if processed_some && num_leaves != 0 {
160            // For `BITS == u128::BITS` `num_leaves` will wrap around back to zero right at the
161            // very end, so we reverse the mechanism here
162            Self::skip_leaves(
163                &mut stack,
164                &mut processed_some,
165                num_leaves,
166                0u128.wrapping_sub(num_leaves),
167            )?;
168        }
169
170        Some(stack[BITS as usize])
171    }
172
173    /// Returns updated number of leaves
174    #[inline]
175    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
176    fn skip_leaves(
177        stack: &mut [[u8; OUT_LEN]; BITS as usize + 1],
178        processed_some: &mut bool,
179        mut num_leaves: u128,
180        mut skip_count: u128,
181    ) -> Option<u128>
182    where
183        [(); BITS as usize + 1]:,
184    {
185        const ZERO: [u8; OUT_LEN] = [0; OUT_LEN];
186
187        if u32::from(BITS) < u128::BITS {
188            // How many leaves were processed so far
189            if num_leaves.checked_add(skip_count)? > 2u128.pow(u32::from(BITS)) {
190                return None;
191            }
192        } else {
193            // For `BITS == u128::BITS` `num_leaves` will wrap around back to zero right at the
194            // very end
195            let (overflow_amount, overflowed) = num_leaves.overflowing_add(skip_count);
196            if *processed_some && overflowed && overflow_amount > 0 {
197                return None;
198            }
199            *processed_some = true;
200        }
201
202        while skip_count > 0 {
203            // Find the largest aligned chunk to skip for the current state of the tree
204            let max_levels_to_skip = skip_count.ilog2().min(num_leaves.trailing_zeros());
205            let chunk_size = 1u128 << max_levels_to_skip;
206
207            let mut level = max_levels_to_skip;
208            let mut current = ZERO;
209            for item in stack.iter().skip(max_levels_to_skip as usize) {
210                // Check active level for merging up the stack.
211                //
212                // `BITS == u128::BITS` condition is only added for better dead code elimination
213                // since that check is only relevant for 2^128 leaves case and nothing else.
214                if (u32::from(BITS) == u128::BITS && level == u128::BITS)
215                    || num_leaves & (1 << level) == 0
216                {
217                    // Level wasn't active before, stop here
218                    break;
219                }
220
221                // Hash together unless both are zero
222                if !(item == &ZERO && current == ZERO) {
223                    current = hash_pair(item, &current);
224                }
225
226                level += 1;
227            }
228            // SAFETY: Level is limited by the number of leaves, which in turn is checked above to
229            // fit into 2^BITS, while `BITS` generic in turn ensured sufficient stack size
230            *unsafe { stack.get_unchecked_mut(level as usize) } = current;
231
232            // Wrapping is needed for `BITS == u128::BITS`, where number of leaves narrowly
233            // doesn't fit into `u128` itself
234            num_leaves = num_leaves.wrapping_add(chunk_size);
235            skip_count -= chunk_size;
236        }
237
238        Some(num_leaves)
239    }
240
241    /// Verify previously generated proof.
242    ///
243    /// Leaf can either be leaf value for a leaf that is occupied or `[0; 32]` for a leaf that is
244    /// supposed to be empty.
245    #[inline]
246    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
247    pub fn verify(
248        root: &[u8; OUT_LEN],
249        proof: &[[u8; OUT_LEN]; BITS as usize],
250        leaf_index: u128,
251        leaf: [u8; OUT_LEN],
252    ) -> bool
253    where
254        [(); BITS as usize]:,
255    {
256        // For `BITS == u128::BITS` any index is valid by definition
257        if u32::from(BITS) < u128::BITS && leaf_index >= 2u128.pow(u32::from(BITS)) {
258            return false;
259        }
260
261        let mut computed_root = leaf;
262
263        let mut position = leaf_index;
264        for hash in proof {
265            computed_root = if position.is_multiple_of(2) {
266                hash_pair(&computed_root, hash)
267            } else {
268                hash_pair(hash, &computed_root)
269            };
270
271            position /= 2;
272        }
273
274        root == &computed_root
275    }
276}