ab_core_primitives/
shard.rs

1//! Shard-related primitives
2
3use crate::hashes::Blake3Hash;
4use crate::nano_u256::NanoU256;
5use crate::segments::HistorySize;
6use crate::solutions::{ShardCommitmentHash, ShardMembershipEntropy, SolutionShardCommitment};
7use ab_blake3::single_block_keyed_hash;
8use ab_io_type::trivial_type::TrivialType;
9use core::num::{NonZeroU16, NonZeroU32, NonZeroU128};
10use core::ops::RangeInclusive;
11use derive_more::Display;
12#[cfg(feature = "scale-codec")]
13use parity_scale_codec::{Decode, Encode, Input, MaxEncodedLen};
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Deserializer, Serialize};
16
/// Shard indices reserved for intermediate shards (index `0` is the beacon chain)
const INTERMEDIATE_SHARDS_RANGE: RangeInclusive<u32> = 1..=1023;
/// Number of least significant bits of a shard index that identify the intermediate shard
const INTERMEDIATE_SHARD_BITS: u32 = 10;
/// Mask that selects the `INTERMEDIATE_SHARD_BITS` least significant bits of a shard index
const INTERMEDIATE_SHARD_MASK: u32 = (1 << INTERMEDIATE_SHARD_BITS) - 1;
20
/// A kind of shard.
///
/// Schematically, the hierarchy of shards is as follows:
/// ```text
///                          Beacon chain
///                          /          \
///      Intermediate shard 1            Intermediate shard 2
///              /  \                            /  \
/// Leaf shard 11   Leaf shard 12   Leaf shard 21   Leaf shard 22
/// ```
///
/// Beacon chain has index 0, intermediate shards indices 1..=1023. Leaf shards share the same least
/// significant 10 bits as their respective intermediate shards, so leaf shards of intermediate
/// shard 1 have indices like 1025, 2049, 3073, etc.
///
/// If represented as least significant bits first (as it will be in the formatted address), shard
/// index 2049 (a leaf shard of intermediate shard 1) looks like this:
/// ```text
/// Intermediate shard bits
///     \            /
///      1000_0000_0001_0000_0000
///                 /            \
///                Leaf shard bits
/// ```
///
/// Note that shards other than the beacon chain itself that have the 10 least significant bits set
/// to 0 (corresponding to the beacon chain) are not leaf shards, in fact, they are not even
/// physical shards that have nodes in general. The meaning of these shards is TBD, currently they
/// are called "phantom" shards and may end up containing some system contracts with special
/// meaning, but no blocks will ever exist for such shards.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ShardKind {
    /// Beacon chain shard
    BeaconChain,
    /// Intermediate shard directly below the beacon chain that has child shards
    IntermediateShard,
    /// Leaf shard, which doesn't have child shards
    LeafShard,
    /// Phantom shard: an index above the intermediate shard range whose 10 least significant bits
    /// are all zero. No blocks exist for such shards; their exact meaning is TBD.
    Phantom,
}

impl ShardKind {
    /// Try to convert to real shard kind.
    ///
    /// Returns `None` for phantom shard, the only kind for which no blocks exist.
    #[inline(always)]
    pub fn to_real(self) -> Option<RealShardKind> {
        match self {
            ShardKind::BeaconChain => Some(RealShardKind::BeaconChain),
            ShardKind::IntermediateShard => Some(RealShardKind::IntermediateShard),
            ShardKind::LeafShard => Some(RealShardKind::LeafShard),
            ShardKind::Phantom => None,
        }
    }
}

/// Real shard kind for which a block may exist, see [`ShardKind`] for more details.
///
/// This is [`ShardKind`] without the `Phantom` variant; conversion back is lossless via `From`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum RealShardKind {
    /// Beacon chain shard
    BeaconChain,
    /// Intermediate shard directly below the beacon chain that has child shards
    IntermediateShard,
    /// Leaf shard, which doesn't have child shards
    LeafShard,
}

impl From<RealShardKind> for ShardKind {
    #[inline(always)]
    fn from(shard_kind: RealShardKind) -> Self {
        match shard_kind {
            RealShardKind::BeaconChain => ShardKind::BeaconChain,
            RealShardKind::IntermediateShard => ShardKind::IntermediateShard,
            RealShardKind::LeafShard => ShardKind::LeafShard,
        }
    }
}
98
/// Shard index.
///
/// Valid values are `0..=ShardIndex::MAX_SHARD_INDEX`; see [`ShardKind`] for how an index maps to
/// the beacon chain, intermediate, leaf and phantom shards. [`ShardIndex::new()`] enforces the
/// upper bound, but the derived decoding implementations (`Decode`, `Deserialize`) only read the
/// inner `u32` and do not check it, which is why methods like [`ShardIndex::shard_kind()`] still
/// have to handle out-of-range values.
#[derive(Debug, Display, Copy, Clone, Hash, Ord, PartialOrd, Eq, PartialEq, TrivialType)]
#[cfg_attr(feature = "scale-codec", derive(Encode, Decode, MaxEncodedLen))]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[repr(C)]
pub struct ShardIndex(u32);
105
106impl ShardIndex {
107    /// Beacon chain
108    pub const BEACON_CHAIN: Self = Self(0);
109    /// Max possible shard index
110    pub const MAX_SHARD_INDEX: u32 = Self::MAX_SHARDS.get() - 1;
111    /// Max possible number of shards
112    pub const MAX_SHARDS: NonZeroU32 = NonZeroU32::new(2u32.pow(20)).expect("Not zero; qed");
113    /// Max possible number of addresses per shard
114    pub const MAX_ADDRESSES_PER_SHARD: NonZeroU128 =
115        NonZeroU128::new(2u128.pow(108)).expect("Not zero; qed");
116
117    /// Create shard index from `u32`.
118    ///
119    /// Returns `None` if `shard_index > ShardIndex::MAX_SHARD_INDEX`
120    ///
121    /// This is typically only necessary for low-level code.
122    #[inline(always)]
123    pub const fn new(shard_index: u32) -> Option<Self> {
124        if shard_index > Self::MAX_SHARD_INDEX {
125            return None;
126        }
127
128        Some(Self(shard_index))
129    }
130
131    // TODO: Remove once traits work in const environment and `From` could be used
132    /// Convert shard index to `u32`.
133    ///
134    /// This is typically only necessary for low-level code.
135    #[inline(always)]
136    pub const fn as_u32(self) -> u32 {
137        self.0
138    }
139
140    /// Whether the shard index corresponds to the beacon chain
141    #[inline(always)]
142    pub const fn is_beacon_chain(&self) -> bool {
143        self.0 == Self::BEACON_CHAIN.0
144    }
145
146    /// Whether the shard index corresponds to an intermediate shard
147    #[inline(always)]
148    pub const fn is_intermediate_shard(&self) -> bool {
149        self.0 >= *INTERMEDIATE_SHARDS_RANGE.start() && self.0 <= *INTERMEDIATE_SHARDS_RANGE.end()
150    }
151
152    /// Whether the shard index corresponds to an intermediate shard
153    #[inline(always)]
154    pub const fn is_leaf_shard(&self) -> bool {
155        if self.0 <= *INTERMEDIATE_SHARDS_RANGE.end() || self.0 > Self::MAX_SHARD_INDEX {
156            return false;
157        }
158
159        self.0 & INTERMEDIATE_SHARD_MASK != 0
160    }
161
162    /// Whether the shard index corresponds to a real shard
163    #[inline(always)]
164    pub const fn is_real(&self) -> bool {
165        !self.is_phantom_shard()
166    }
167
168    /// Whether the shard index corresponds to a phantom shard
169    #[inline(always)]
170    pub const fn is_phantom_shard(&self) -> bool {
171        if self.0 <= *INTERMEDIATE_SHARDS_RANGE.end() || self.0 > Self::MAX_SHARD_INDEX {
172            return false;
173        }
174
175        self.0 & INTERMEDIATE_SHARD_MASK == 0
176    }
177
178    /// Check if this shard is a child shard of `parent`
179    #[inline(always)]
180    pub const fn is_child_of(self, parent: Self) -> bool {
181        match self.shard_kind() {
182            Some(ShardKind::BeaconChain) => false,
183            Some(ShardKind::IntermediateShard | ShardKind::Phantom) => parent.is_beacon_chain(),
184            Some(ShardKind::LeafShard) => {
185                // Check that the least significant bits match
186                self.0 & INTERMEDIATE_SHARD_MASK == parent.0
187            }
188            None => false,
189        }
190    }
191
192    /// Get index of the parent shard (for leaf and intermediate shards)
193    #[inline(always)]
194    pub const fn parent_shard(self) -> Option<ShardIndex> {
195        match self.shard_kind()? {
196            ShardKind::BeaconChain => None,
197            ShardKind::IntermediateShard | ShardKind::Phantom => Some(ShardIndex::BEACON_CHAIN),
198            ShardKind::LeafShard => Some(Self(self.0 & INTERMEDIATE_SHARD_MASK)),
199        }
200    }
201
202    /// Get shard kind
203    #[inline(always)]
204    pub const fn shard_kind(&self) -> Option<ShardKind> {
205        if self.0 == Self::BEACON_CHAIN.0 {
206            Some(ShardKind::BeaconChain)
207        } else if self.0 >= *INTERMEDIATE_SHARDS_RANGE.start()
208            && self.0 <= *INTERMEDIATE_SHARDS_RANGE.end()
209        {
210            Some(ShardKind::IntermediateShard)
211        } else if self.0 > Self::MAX_SHARD_INDEX {
212            None
213        } else if self.0 & INTERMEDIATE_SHARD_MASK == 0 {
214            // Check if the least significant bits correspond to the beacon chain
215            Some(ShardKind::Phantom)
216        } else {
217            Some(ShardKind::LeafShard)
218        }
219    }
220}
221
/// Unchecked number of shards in the network.
///
/// A plain `#[repr(C)]` container without invariants: either field may be zero or exceed the
/// limits accepted by [`NumShards::new()`].
///
/// Should be converted into [`NumShards`] (using its `TryFrom` implementation) before use.
#[derive(Debug, Copy, Clone, Eq, PartialEq, TrivialType)]
#[cfg_attr(feature = "scale-codec", derive(Encode, Decode, MaxEncodedLen))]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[repr(C)]
pub struct NumShardsUnchecked {
    /// The number of intermediate shards (not validated, may be zero)
    pub intermediate_shards: u16,
    /// The number of leaf shards per intermediate shard (not validated, may be zero)
    pub leaf_shards_per_intermediate_shard: u16,
}
235
236impl From<NumShards> for NumShardsUnchecked {
237    fn from(value: NumShards) -> Self {
238        Self {
239            intermediate_shards: value.intermediate_shards.get(),
240            leaf_shards_per_intermediate_shard: value.leaf_shards_per_intermediate_shard.get(),
241        }
242    }
243}
244
/// Number of shards in the network.
///
/// Unlike [`NumShardsUnchecked`], instances of this type are only created through
/// [`NumShards::new()`] (directly or via the `TryFrom`, `Decode` and `Deserialize`
/// implementations), so both counts are non-zero and within the limits checked there.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "scale-codec", derive(Encode, MaxEncodedLen))]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct NumShards {
    // NOTE: the derived `Encode` writes fields in declaration order and the manual `Decode`
    // implementation reads them in the same order, so the field order must not change.
    /// The number of intermediate shards
    intermediate_shards: NonZeroU16,
    /// The number of leaf shards per intermediate shard
    leaf_shards_per_intermediate_shard: NonZeroU16,
}
255
#[cfg(feature = "scale-codec")]
impl Decode for NumShards {
    /// Decode both counts in declaration order, then validate them together with
    /// `NumShards::new`.
    fn decode<I: Input>(input: &mut I) -> Result<Self, parity_scale_codec::Error> {
        let intermediate_shards = <NonZeroU16 as Decode>::decode(input)
            .map_err(|error| error.chain("Could not decode `NumShards::intermediate_shards`"))?;
        let leaf_shards_per_intermediate_shard =
            <NonZeroU16 as Decode>::decode(input).map_err(|error| {
                error.chain("Could not decode `NumShards::leaf_shards_per_intermediate_shard`")
            })?;

        match Self::new(intermediate_shards, leaf_shards_per_intermediate_shard) {
            Some(num_shards) => Ok(num_shards),
            None => Err("Invalid `NumShards`".into()),
        }
    }
}
269
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for NumShards {
    /// Deserialize the raw counts and validate them with `NumShards::new`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Unvalidated mirror of the outer type. The derive uses the struct name at runtime (e.g. in
        // error messages), so it intentionally keeps the outer type's name.
        #[derive(Deserialize)]
        struct NumShards {
            intermediate_shards: NonZeroU16,
            leaf_shards_per_intermediate_shard: NonZeroU16,
        }

        let NumShards {
            intermediate_shards,
            leaf_shards_per_intermediate_shard,
        } = NumShards::deserialize(deserializer)?;

        Self::new(intermediate_shards, leaf_shards_per_intermediate_shard)
            .ok_or_else(|| serde::de::Error::custom("Invalid `NumShards`"))
    }
}
291
292impl TryFrom<NumShardsUnchecked> for NumShards {
293    type Error = ();
294
295    fn try_from(value: NumShardsUnchecked) -> Result<Self, Self::Error> {
296        Self::new(
297            NonZeroU16::new(value.intermediate_shards).ok_or(())?,
298            NonZeroU16::new(value.leaf_shards_per_intermediate_shard).ok_or(())?,
299        )
300        .ok_or(())
301    }
302}
303
304impl NumShards {
305    /// Create a new instance from a number of intermediate shards and leaf shards per
306    /// intermediate shard.
307    ///
308    /// Returns `None` if inputs are invalid.
309    ///
310    /// This is typically only necessary for low-level code.
311    #[inline(always)]
312    pub const fn new(
313        intermediate_shards: NonZeroU16,
314        leaf_shards_per_intermediate_shard: NonZeroU16,
315    ) -> Option<Self> {
316        if intermediate_shards.get()
317            > (*INTERMEDIATE_SHARDS_RANGE.end() - *INTERMEDIATE_SHARDS_RANGE.start() + 1) as u16
318        {
319            return None;
320        }
321
322        let num_shards = Self {
323            intermediate_shards,
324            leaf_shards_per_intermediate_shard,
325        };
326
327        if num_shards.leaf_shards() > ShardIndex::MAX_SHARDS {
328            return None;
329        }
330
331        Some(num_shards)
332    }
333
334    /// The number of intermediate shards
335    #[inline(always)]
336    pub const fn intermediate_shards(self) -> NonZeroU16 {
337        self.intermediate_shards
338    }
339    /// The number of leaf shards per intermediate shard
340    #[inline(always)]
341    pub const fn leaf_shards_per_intermediate_shard(self) -> NonZeroU16 {
342        self.leaf_shards_per_intermediate_shard
343    }
344
345    /// Total number of leaf shards in the network
346    #[inline(always)]
347    pub const fn leaf_shards(&self) -> NonZeroU32 {
348        NonZeroU32::new(
349            self.intermediate_shards.get() as u32
350                * self.leaf_shards_per_intermediate_shard.get() as u32,
351        )
352        .expect("Not zero; qed")
353    }
354
355    /// Iterator over all intermediate shards
356    #[inline(always)]
357    pub fn iter_intermediate_shards(&self) -> impl Iterator<Item = ShardIndex> {
358        INTERMEDIATE_SHARDS_RANGE
359            .take(usize::from(self.intermediate_shards.get()))
360            .map(ShardIndex)
361    }
362
363    /// Iterator over all intermediate shards
364    #[inline(always)]
365    pub fn iter_leaf_shards(&self) -> impl Iterator<Item = ShardIndex> {
366        self.iter_intermediate_shards()
367            .flat_map(|intermediate_shard| {
368                (0..u32::from(self.leaf_shards_per_intermediate_shard.get())).map(
369                    move |leaf_shard_index| {
370                        ShardIndex(
371                            (leaf_shard_index << INTERMEDIATE_SHARD_BITS) | intermediate_shard.0,
372                        )
373                    },
374                )
375            })
376    }
377
378    /// Derive shard index that should be used in a solution
379    #[inline]
380    pub fn derive_shard_index(
381        &self,
382        public_key_hash: &Blake3Hash,
383        shard_commitments_root: &ShardCommitmentHash,
384        shard_membership_entropy: &ShardMembershipEntropy,
385        history_size: HistorySize,
386    ) -> ShardIndex {
387        let hash = single_block_keyed_hash(public_key_hash, &{
388            let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
389                + ShardMembershipEntropy::SIZE
390                + HistorySize::SIZE as usize];
391            bytes_to_hash[..ShardCommitmentHash::SIZE]
392                .copy_from_slice(shard_commitments_root.as_bytes());
393            bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
394                .copy_from_slice(shard_membership_entropy.as_bytes());
395            bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
396                .copy_from_slice(history_size.as_bytes());
397            bytes_to_hash
398        })
399        .expect("Input is smaller than block size; qed");
400        // Going through `NanoU256` because the total number of shards is not guaranteed to be a
401        // power of two
402        let shard_index_offset =
403            NanoU256::from_le_bytes(hash) % u64::from(self.leaf_shards().get());
404
405        self.iter_leaf_shards()
406            .nth(shard_index_offset as usize)
407            .unwrap_or(ShardIndex::BEACON_CHAIN)
408    }
409
410    /// Derive shard commitment index that should be used in a solution.
411    ///
412    /// Returned index is always within `0`..[`SolutionShardCommitment::NUM_LEAVES`] range.
413    #[inline]
414    pub fn derive_shard_commitment_index(
415        &self,
416        public_key_hash: &Blake3Hash,
417        shard_commitments_root: &ShardCommitmentHash,
418        shard_membership_entropy: &ShardMembershipEntropy,
419        history_size: HistorySize,
420    ) -> u32 {
421        let hash = single_block_keyed_hash(public_key_hash, &{
422            let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
423                + ShardMembershipEntropy::SIZE
424                + HistorySize::SIZE as usize];
425            bytes_to_hash[..ShardCommitmentHash::SIZE]
426                .copy_from_slice(shard_commitments_root.as_bytes());
427            bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
428                .copy_from_slice(shard_membership_entropy.as_bytes());
429            bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
430                .copy_from_slice(history_size.as_bytes());
431            bytes_to_hash
432        })
433        .expect("Input is smaller than block size; qed");
434        const {
435            assert!(SolutionShardCommitment::NUM_LEAVES.is_power_of_two());
436        }
437        u32::from_le_bytes([hash[0], hash[1], hash[2], hash[3]])
438            % SolutionShardCommitment::NUM_LEAVES as u32
439    }
440
441    /// More efficient version of [`Self::derive_shard_index()`] and
442    /// [`Self::derive_shard_commitment_index()`] in a single call, see those functions for details
443    #[inline]
444    pub fn derive_shard_index_and_shard_commitment_index(
445        &self,
446        public_key_hash: &Blake3Hash,
447        shard_commitments_root: &ShardCommitmentHash,
448        shard_membership_entropy: &ShardMembershipEntropy,
449        history_size: HistorySize,
450    ) -> (ShardIndex, u32) {
451        let hash = single_block_keyed_hash(public_key_hash, &{
452            let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
453                + ShardMembershipEntropy::SIZE
454                + HistorySize::SIZE as usize];
455            bytes_to_hash[..ShardCommitmentHash::SIZE]
456                .copy_from_slice(shard_commitments_root.as_bytes());
457            bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
458                .copy_from_slice(shard_membership_entropy.as_bytes());
459            bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
460                .copy_from_slice(history_size.as_bytes());
461            bytes_to_hash
462        })
463        .expect("Input is smaller than block size; qed");
464
465        // Going through `NanoU256` because the total number of shards is not guaranteed to be a
466        // power of two
467        let shard_index_offset =
468            NanoU256::from_le_bytes(hash) % u64::from(self.leaf_shards().get());
469
470        let shard_index = self
471            .iter_leaf_shards()
472            .nth(shard_index_offset as usize)
473            .unwrap_or(ShardIndex::BEACON_CHAIN);
474
475        const {
476            assert!(SolutionShardCommitment::NUM_LEAVES.is_power_of_two());
477        }
478        let shard_commitment_index = u32::from_le_bytes([hash[0], hash[1], hash[2], hash[3]])
479            % SolutionShardCommitment::NUM_LEAVES as u32;
480
481        (shard_index, shard_commitment_index)
482    }
483}