// ab_merkle_tree/sparse.rs

//! Sparse Merkle Tree and related data structures.
//!
//! A Sparse Merkle Tree is essentially a huge Balanced Merkle Tree where most of the leaves are
//! empty. By "empty" here we mean `[0u8; 32]`. To optimize proofs and their verification, the
//! hashing function is customized: it returns `[0u8; 32]` when both the left and right branches
//! are `[0u8; 32]`, otherwise BLAKE3 hash is used like in a Balanced Merkle Tree.

8use crate::{OUT_LEN, hash_pair};
9use core::num::NonZeroU128;
10
/// Ensures only supported `BITS` values can be specified for [`SparseMerkleTree`].
///
/// This is essentially a workaround for the current Rust type system constraints that do not allow
/// a nicer way to do the same thing at compile time: the function is evaluated in a `where` clause
/// of the form `[(); ensure_supported_bits(BITS)]:`, so an unsupported value fails const
/// evaluation instead of producing a broken tree.
///
/// Always returns `0` for supported values (the result is only used as an array length).
///
/// # Panics
///
/// Panics (at compile time when evaluated in a const context) if `bits` is `0` (the tree must have
/// more than one leaf) or greater than `128` (more than `2^128` leaves are not supported).
pub const fn ensure_supported_bits(bits: u8) -> usize {
    assert!(
        bits <= 128,
        "This Sparse Merkle Tree doesn't support more than 2^128 leaves"
    );

    assert!(
        bits != 0,
        "This Sparse Merkle Tree must have more than one leaf"
    );

    0
}

29/// Sparse Merkle Tree Leaf
30#[derive(Debug)]
31pub enum Leaf<'a> {
32    // TODO: Batch of leaves for efficiently, especially with SIMD?
33    /// Leaf contains a value
34    Occupied {
35        /// Leaf value
36        leaf: &'a [u8; OUT_LEN],
37    },
38    /// Leaf contains a value (owned)
39    OccupiedOwned {
40        /// Leaf value
41        leaf: [u8; OUT_LEN],
42    },
43    /// Leaf is empty
44    Empty {
45        /// Number of consecutive empty leaves
46        skip_count: NonZeroU128,
47    },
48}
49
50impl<'a> From<&'a [u8; OUT_LEN]> for Leaf<'a> {
51    #[inline(always)]
52    fn from(leaf: &'a [u8; OUT_LEN]) -> Self {
53        Self::Occupied { leaf }
54    }
55}
56
// TODO: A version that can hold intermediate nodes in memory, efficiently update leaves, etc.
/// Sparse Merkle Tree variant that has hash-sized leaves, with most leaves being empty
/// (have value `[0u8; 32]`).
///
/// In contrast to a proper Balanced Merkle Tree, constant `BITS` here specifies the max number of
/// leaves hypothetically possible in a tree (`2^BITS`, often intractable to enumerate), rather
/// than the number of non-empty leaves actually present.
///
/// Methods are only available for `BITS` in `1..=128`, see [`ensure_supported_bits`].
#[derive(Debug)]
pub struct SparseMerkleTree<const BITS: u8>;

67// TODO: Optimize by implementing SIMD-accelerated hashing of multiple values:
68//  https://github.com/BLAKE3-team/BLAKE3/issues/478
69impl<const BITS: u8> SparseMerkleTree<BITS>
70where
71    [(); ensure_supported_bits(BITS)]:,
72{
73    // TODO: Method that generates not only root, but also proof, like Unbalanced Merkle Tree
74    /// Compute Merkle Tree root.
75    ///
76    /// If provided iterator ends early, it means the rest of the leaves are empty.
77    ///
78    /// There must be no [`Leaf::Occupied`] for empty/unoccupied leaves or else they may result in
79    /// invalid root, [`Leaf::Empty`] must be used instead.
80    ///
81    /// Returns `None` if too many leaves were provided.
82    #[inline]
83    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
84    pub fn compute_root_only<'a, Iter>(leaves: Iter) -> Option<[u8; OUT_LEN]>
85    where
86        [(); BITS as usize + 1]:,
87        Iter: IntoIterator<Item = Leaf<'a>> + 'a,
88    {
89        // Stack of intermediate nodes per tree level
90        let mut stack = [[0u8; OUT_LEN]; BITS as usize + 1];
91        let mut processed_some = false;
92        let mut num_leaves = 0_u128;
93
94        for leaf in leaves {
95            if u32::from(BITS) < u128::BITS {
96                // How many leaves were processed so far
97                if num_leaves == 2u128.pow(u32::from(BITS)) {
98                    return None;
99                }
100            } else {
101                // For `BITS == u128::BITS` `num_leaves` will wrap around back to zero right at the
102                // very end
103                if processed_some && num_leaves == 0 {
104                    return None;
105                }
106                processed_some = true;
107            }
108
109            let leaf = match leaf {
110                Leaf::Occupied { leaf } => *leaf,
111                Leaf::OccupiedOwned { leaf } => leaf,
112                Leaf::Empty { skip_count } => {
113                    num_leaves = Self::skip_leaves(
114                        &mut stack,
115                        &mut processed_some,
116                        num_leaves,
117                        skip_count.get(),
118                    )?;
119                    continue;
120                }
121            };
122
123            let mut current = leaf;
124
125            // Every bit set to `1` corresponds to an active Merkle Tree level
126            let lowest_active_levels = num_leaves.trailing_ones() as usize;
127            for item in stack.iter().take(lowest_active_levels) {
128                current = hash_pair(item, &current);
129            }
130
131            // Place the current hash at the first inactive level
132            // SAFETY: Number of lowest active levels corresponds to the number of inserted
133            // elements, which in turn is checked above to fit into 2^BITS, while `BITS`
134            // generic in turn ensured sufficient stack size
135            *unsafe { stack.get_unchecked_mut(lowest_active_levels) } = current;
136            // Wrapping is needed for `BITS == u128::BITS`, where number of leaves narrowly
137            // doesn't fit into `u128` itself
138            num_leaves = num_leaves.wrapping_add(1);
139        }
140
141        if u32::from(BITS) < u128::BITS {
142            Self::skip_leaves(
143                &mut stack,
144                &mut processed_some,
145                num_leaves,
146                2u128.pow(u32::from(BITS)) - num_leaves,
147            )?;
148        } else if processed_some && num_leaves != 0 {
149            // For `BITS == u128::BITS` `num_leaves` will wrap around back to zero right at the
150            // very end, so we reverse the mechanism here
151            Self::skip_leaves(
152                &mut stack,
153                &mut processed_some,
154                num_leaves,
155                0u128.wrapping_sub(num_leaves),
156            )?;
157        }
158
159        Some(stack[BITS as usize])
160    }
161
162    /// Returns updated number of leaves
163    #[inline]
164    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
165    fn skip_leaves(
166        stack: &mut [[u8; OUT_LEN]; BITS as usize + 1],
167        processed_some: &mut bool,
168        mut num_leaves: u128,
169        mut skip_count: u128,
170    ) -> Option<u128>
171    where
172        [(); BITS as usize + 1]:,
173    {
174        const ZERO: [u8; OUT_LEN] = [0; OUT_LEN];
175
176        if u32::from(BITS) < u128::BITS {
177            // How many leaves were processed so far
178            if num_leaves.checked_add(skip_count)? > 2u128.pow(u32::from(BITS)) {
179                return None;
180            }
181        } else {
182            // For `BITS == u128::BITS` `num_leaves` will wrap around back to zero right at the
183            // very end
184            let (overflow_amount, overflowed) = num_leaves.overflowing_add(skip_count);
185            if *processed_some && overflowed && overflow_amount > 0 {
186                return None;
187            }
188            *processed_some = true;
189        }
190
191        while skip_count > 0 {
192            // Find the largest aligned chunk to skip for the current state of the tree
193            let max_levels_to_skip = skip_count.ilog2().min(num_leaves.trailing_zeros());
194            let chunk_size = 1u128 << max_levels_to_skip;
195
196            let mut level = max_levels_to_skip;
197            let mut current = ZERO;
198            for item in stack.iter().skip(max_levels_to_skip as usize) {
199                // Check the active level for merging up the stack.
200                //
201                // `BITS == u128::BITS` condition is only added for better dead code elimination
202                // since that check is only relevant for 2^128 leaves case and nothing else.
203                if (u32::from(BITS) == u128::BITS && level == u128::BITS)
204                    || num_leaves & (1 << level) == 0
205                {
206                    // Level wasn't active before, stop here
207                    break;
208                }
209
210                // Hash together unless both are zero
211                if !(item == &ZERO && current == ZERO) {
212                    current = hash_pair(item, &current);
213                }
214
215                level += 1;
216            }
217            // SAFETY: Level is limited by the number of leaves, which in turn is checked above to
218            // fit into 2^BITS, while `BITS` generic in turn ensured sufficient stack size
219            *unsafe { stack.get_unchecked_mut(level as usize) } = current;
220
221            // Wrapping is needed for `BITS == u128::BITS`, where number of leaves narrowly
222            // doesn't fit into `u128` itself
223            num_leaves = num_leaves.wrapping_add(chunk_size);
224            skip_count -= chunk_size;
225        }
226
227        Some(num_leaves)
228    }
229
230    /// Verify previously generated proof.
231    ///
232    /// Leaf can either be leaf value for a leaf that is occupied or `[0; 32]` for a leaf that is
233    /// supposed to be empty.
234    #[inline]
235    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
236    pub fn verify(
237        root: &[u8; OUT_LEN],
238        proof: &[[u8; OUT_LEN]; BITS as usize],
239        leaf_index: u128,
240        leaf: [u8; OUT_LEN],
241    ) -> bool
242    where
243        [(); BITS as usize]:,
244    {
245        // For `BITS == u128::BITS` any index is valid by definition
246        if u32::from(BITS) < u128::BITS && leaf_index >= 2u128.pow(u32::from(BITS)) {
247            return false;
248        }
249
250        let mut computed_root = leaf;
251
252        let mut position = leaf_index;
253        for hash in proof {
254            computed_root = if position.is_multiple_of(2) {
255                hash_pair(&computed_root, hash)
256            } else {
257                hash_pair(hash, &computed_root)
258            };
259
260            position /= 2;
261        }
262
263        root == &computed_root
264    }
265}