Skip to main content

ab_core_primitives/
shard.rs

1//! Shard-related primitives
2
3use crate::hashes::Blake3Hash;
4use crate::nano_u256::NanoU256;
5use crate::segments::HistorySize;
6use crate::solutions::{ShardCommitmentHash, ShardMembershipEntropy, SolutionShardCommitment};
7use ab_blake3::single_block_keyed_hash;
8use ab_io_type::trivial_type::TrivialType;
9use core::num::{NonZeroU16, NonZeroU32, NonZeroU128};
10use core::ops::RangeInclusive;
11use derive_more::Display;
12#[cfg(feature = "scale-codec")]
13use parity_scale_codec::{Decode, Encode, Input, MaxEncodedLen};
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Deserializer, Serialize};
16
/// Number of least significant bits of a shard index that identify its intermediate shard
const INTERMEDIATE_SHARD_BITS: u32 = 10;
/// Mask selecting the intermediate shard bits of a shard index (`0b11_1111_1111`)
const INTERMEDIATE_SHARD_MASK: u32 = (1 << INTERMEDIATE_SHARD_BITS) - 1;
/// Indices of intermediate shards: every non-zero value that fits into the intermediate shard bits
const INTERMEDIATE_SHARDS_RANGE: RangeInclusive<u32> = 1..=INTERMEDIATE_SHARD_MASK;
20
/// A kind of shard.
///
/// Schematically, the hierarchy of shards is as follows:
/// ```text
///                          Beacon chain
///                          /          \
///      Intermediate shard 1            Intermediate shard 2
///              /  \                            /  \
/// Leaf shard 11   Leaf shard 12   Leaf shard 21   Leaf shard 22
/// ```
///
/// Beacon chain has index 0, intermediate shards have indices 1..=1023. Leaf shards share the same
/// least significant 10 bits as their respective intermediate shards, so leaf shards of
/// intermediate shard 1 have indices like 1025, 2049, 3073, etc.
///
/// If represented as least significant bits first (as it will be in the formatted address), index
/// 2049 (a leaf shard of intermediate shard 1) looks like this:
/// ```text
/// Intermediate shard bits
///     \            /
///      1000_0000_0001_0000_0000
///                 /            \
///                Leaf shard bits
/// ```
///
/// Note that shards that have 10 least significant bits set to 0 (corresponds to the beacon chain)
/// are not leaf shards, in fact, they are not even physical shards that have nodes in general. The
/// meaning of these shards is TBD, currently they are called "phantom" shards and may end up
/// containing some system contracts with special meaning, but no blocks will ever exist for such
/// shards.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ShardKind {
    /// Beacon chain shard
    BeaconChain,
    /// Intermediate shard directly below the beacon chain that has child shards
    IntermediateShard,
    /// Leaf shard, which doesn't have child shards
    LeafShard,
    /// Phantom shard: its 10 least significant bits are all zero, but it is not the beacon chain
    /// itself. No blocks will ever exist for such a shard; its meaning is still TBD.
    Phantom,
}

impl ShardKind {
    /// Try to convert to real shard kind.
    ///
    /// Returns `None` for phantom shard, since no blocks can exist for it.
    #[inline(always)]
    pub const fn to_real(self) -> Option<RealShardKind> {
        match self {
            ShardKind::BeaconChain => Some(RealShardKind::BeaconChain),
            ShardKind::IntermediateShard => Some(RealShardKind::IntermediateShard),
            ShardKind::LeafShard => Some(RealShardKind::LeafShard),
            ShardKind::Phantom => None,
        }
    }
}

/// Real shard kind for which a block may exist, see [`ShardKind`] for more details
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum RealShardKind {
    /// Beacon chain shard
    BeaconChain,
    /// Intermediate shard directly below the beacon chain that has child shards
    IntermediateShard,
    /// Leaf shard, which doesn't have child shards
    LeafShard,
}

impl From<RealShardKind> for ShardKind {
    /// Lossless conversion; every real shard kind has a counterpart in [`ShardKind`]
    #[inline(always)]
    fn from(shard_kind: RealShardKind) -> Self {
        match shard_kind {
            RealShardKind::BeaconChain => ShardKind::BeaconChain,
            RealShardKind::IntermediateShard => ShardKind::IntermediateShard,
            RealShardKind::LeafShard => ShardKind::LeafShard,
        }
    }
}
98
99/// Shard index
100#[derive(Debug, Display, Copy, Clone, Hash, Ord, PartialOrd, Eq, PartialEq, TrivialType)]
101#[cfg_attr(feature = "scale-codec", derive(Encode, Decode, MaxEncodedLen))]
102#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
103#[repr(C)]
104pub struct ShardIndex(u32);
105
106impl const From<ShardIndex> for u32 {
107    #[inline(always)]
108    fn from(shard_index: ShardIndex) -> Self {
109        shard_index.0
110    }
111}
112
113impl ShardIndex {
114    /// Beacon chain
115    pub const BEACON_CHAIN: Self = Self(0);
116    /// Max possible shard index
117    pub const MAX_SHARD_INDEX: u32 = Self::MAX_SHARDS.get() - 1;
118    /// Max possible number of shards
119    pub const MAX_SHARDS: NonZeroU32 = NonZeroU32::new(2u32.pow(20)).expect("Not zero; qed");
120    /// Max possible number of addresses per shard
121    pub const MAX_ADDRESSES_PER_SHARD: NonZeroU128 =
122        NonZeroU128::new(2u128.pow(108)).expect("Not zero; qed");
123
124    /// Create shard index from `u32`.
125    ///
126    /// Returns `None` if `shard_index > ShardIndex::MAX_SHARD_INDEX`
127    ///
128    /// This is typically only necessary for low-level code.
129    #[inline(always)]
130    pub const fn new(shard_index: u32) -> Option<Self> {
131        if shard_index > Self::MAX_SHARD_INDEX {
132            return None;
133        }
134
135        Some(Self(shard_index))
136    }
137
138    /// Whether the shard index corresponds to the beacon chain
139    #[inline(always)]
140    pub const fn is_beacon_chain(&self) -> bool {
141        self.0 == Self::BEACON_CHAIN.0
142    }
143
144    /// Whether the shard index corresponds to an intermediate shard
145    #[inline(always)]
146    pub const fn is_intermediate_shard(&self) -> bool {
147        self.0 >= *INTERMEDIATE_SHARDS_RANGE.start() && self.0 <= *INTERMEDIATE_SHARDS_RANGE.end()
148    }
149
150    /// Whether the shard index corresponds to an intermediate shard
151    #[inline(always)]
152    pub const fn is_leaf_shard(&self) -> bool {
153        if self.0 <= *INTERMEDIATE_SHARDS_RANGE.end() || self.0 > Self::MAX_SHARD_INDEX {
154            return false;
155        }
156
157        self.0 & INTERMEDIATE_SHARD_MASK != 0
158    }
159
160    /// Whether the shard index corresponds to a real shard
161    #[inline(always)]
162    pub const fn is_real(&self) -> bool {
163        !self.is_phantom_shard()
164    }
165
166    /// Whether the shard index corresponds to a phantom shard
167    #[inline(always)]
168    pub const fn is_phantom_shard(&self) -> bool {
169        if self.0 <= *INTERMEDIATE_SHARDS_RANGE.end() || self.0 > Self::MAX_SHARD_INDEX {
170            return false;
171        }
172
173        self.0 & INTERMEDIATE_SHARD_MASK == 0
174    }
175
176    /// Check if this shard is a child shard of `parent`
177    #[inline(always)]
178    pub const fn is_child_of(self, parent: Self) -> bool {
179        match self.shard_kind() {
180            Some(ShardKind::BeaconChain) => false,
181            Some(ShardKind::IntermediateShard | ShardKind::Phantom) => parent.is_beacon_chain(),
182            Some(ShardKind::LeafShard) => {
183                // Check that the least significant bits match
184                self.0 & INTERMEDIATE_SHARD_MASK == parent.0
185            }
186            None => false,
187        }
188    }
189
190    /// Get index of the parent shard (for leaf and intermediate shards)
191    #[inline(always)]
192    pub const fn parent_shard(self) -> Option<ShardIndex> {
193        match self.shard_kind()? {
194            ShardKind::BeaconChain => None,
195            ShardKind::IntermediateShard | ShardKind::Phantom => Some(ShardIndex::BEACON_CHAIN),
196            ShardKind::LeafShard => Some(Self(self.0 & INTERMEDIATE_SHARD_MASK)),
197        }
198    }
199
200    /// Get shard kind
201    #[inline(always)]
202    pub const fn shard_kind(&self) -> Option<ShardKind> {
203        if self.0 == Self::BEACON_CHAIN.0 {
204            Some(ShardKind::BeaconChain)
205        } else if self.0 >= *INTERMEDIATE_SHARDS_RANGE.start()
206            && self.0 <= *INTERMEDIATE_SHARDS_RANGE.end()
207        {
208            Some(ShardKind::IntermediateShard)
209        } else if self.0 > Self::MAX_SHARD_INDEX {
210            None
211        } else if self.0 & INTERMEDIATE_SHARD_MASK == 0 {
212            // Check if the least significant bits correspond to the beacon chain
213            Some(ShardKind::Phantom)
214        } else {
215            Some(ShardKind::LeafShard)
216        }
217    }
218}
219
220/// Unchecked number of shards in the network.
221///
222/// Should be converted into [`NumShards`] before use.
223#[derive(Debug, Copy, Clone, Eq, PartialEq, TrivialType)]
224#[cfg_attr(feature = "scale-codec", derive(Encode, Decode, MaxEncodedLen))]
225#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
226#[repr(C)]
227pub struct NumShardsUnchecked {
228    /// The number of intermediate shards
229    pub intermediate_shards: u16,
230    /// The number of leaf shards per intermediate shard
231    pub leaf_shards_per_intermediate_shard: u16,
232}
233
234impl From<NumShards> for NumShardsUnchecked {
235    fn from(value: NumShards) -> Self {
236        Self {
237            intermediate_shards: value.intermediate_shards.get(),
238            leaf_shards_per_intermediate_shard: value.leaf_shards_per_intermediate_shard.get(),
239        }
240    }
241}
242
/// Number of shards in the network, validated by [`NumShards::new()`]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "scale-codec", derive(Encode, MaxEncodedLen))]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct NumShards {
    /// How many intermediate shards exist
    intermediate_shards: NonZeroU16,
    /// How many leaf shards each intermediate shard has
    leaf_shards_per_intermediate_shard: NonZeroU16,
}
253
#[cfg(feature = "scale-codec")]
impl Decode for NumShards {
    /// Decodes both counts in field order, then validates them together via [`NumShards::new()`]
    fn decode<I: Input>(input: &mut I) -> Result<Self, parity_scale_codec::Error> {
        let intermediate_shards = NonZeroU16::decode(input)
            .map_err(|error| error.chain("Could not decode `NumShards::intermediate_shards`"))?;
        let leaf_shards_per_intermediate_shard = NonZeroU16::decode(input).map_err(|error| {
            error.chain("Could not decode `NumShards::leaf_shards_per_intermediate_shard`")
        })?;

        match Self::new(intermediate_shards, leaf_shards_per_intermediate_shard) {
            Some(num_shards) => Ok(num_shards),
            None => Err("Invalid `NumShards`".into()),
        }
    }
}
267
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for NumShards {
    /// Deserializes the raw counts, then validates them via [`NumShards::new()`]
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Unvalidated mirror of the outer type; it keeps the same type and field names so it
        // accepts what the derived `Serialize` of the outer `NumShards` produces
        #[derive(Deserialize)]
        struct NumShards {
            intermediate_shards: NonZeroU16,
            leaf_shards_per_intermediate_shard: NonZeroU16,
        }

        let NumShards {
            intermediate_shards,
            leaf_shards_per_intermediate_shard,
        } = NumShards::deserialize(deserializer)?;

        match Self::new(intermediate_shards, leaf_shards_per_intermediate_shard) {
            Some(num_shards) => Ok(num_shards),
            None => Err(serde::de::Error::custom("Invalid `NumShards`")),
        }
    }
}
289
290impl TryFrom<NumShardsUnchecked> for NumShards {
291    type Error = ();
292
293    fn try_from(value: NumShardsUnchecked) -> Result<Self, Self::Error> {
294        Self::new(
295            NonZeroU16::new(value.intermediate_shards).ok_or(())?,
296            NonZeroU16::new(value.leaf_shards_per_intermediate_shard).ok_or(())?,
297        )
298        .ok_or(())
299    }
300}
301
302impl NumShards {
303    /// Create a new instance from a number of intermediate shards and leaf shards per
304    /// intermediate shard.
305    ///
306    /// Returns `None` if inputs are invalid.
307    ///
308    /// This is typically only necessary for low-level code.
309    #[inline(always)]
310    pub const fn new(
311        intermediate_shards: NonZeroU16,
312        leaf_shards_per_intermediate_shard: NonZeroU16,
313    ) -> Option<Self> {
314        if intermediate_shards.get()
315            > (*INTERMEDIATE_SHARDS_RANGE.end() - *INTERMEDIATE_SHARDS_RANGE.start() + 1) as u16
316        {
317            return None;
318        }
319
320        let num_shards = Self {
321            intermediate_shards,
322            leaf_shards_per_intermediate_shard,
323        };
324
325        if num_shards.leaf_shards() > ShardIndex::MAX_SHARDS {
326            return None;
327        }
328
329        Some(num_shards)
330    }
331
332    /// The number of intermediate shards
333    #[inline(always)]
334    pub const fn intermediate_shards(self) -> NonZeroU16 {
335        self.intermediate_shards
336    }
337    /// The number of leaf shards per intermediate shard
338    #[inline(always)]
339    pub const fn leaf_shards_per_intermediate_shard(self) -> NonZeroU16 {
340        self.leaf_shards_per_intermediate_shard
341    }
342
343    /// Total number of leaf shards in the network
344    #[inline(always)]
345    pub const fn leaf_shards(&self) -> NonZeroU32 {
346        NonZeroU32::new(
347            self.intermediate_shards.get() as u32
348                * self.leaf_shards_per_intermediate_shard.get() as u32,
349        )
350        .expect("Not zero; qed")
351    }
352
353    /// Iterator over all intermediate shards
354    #[inline(always)]
355    pub fn iter_intermediate_shards(&self) -> impl Iterator<Item = ShardIndex> {
356        INTERMEDIATE_SHARDS_RANGE
357            .take(usize::from(self.intermediate_shards.get()))
358            .map(ShardIndex)
359    }
360
361    /// Iterator over all intermediate shards
362    #[inline(always)]
363    pub fn iter_leaf_shards(&self) -> impl Iterator<Item = ShardIndex> {
364        self.iter_intermediate_shards()
365            .flat_map(|intermediate_shard| {
366                (0..u32::from(self.leaf_shards_per_intermediate_shard.get())).map(
367                    move |leaf_shard_index| {
368                        ShardIndex(
369                            (leaf_shard_index << INTERMEDIATE_SHARD_BITS) | intermediate_shard.0,
370                        )
371                    },
372                )
373            })
374    }
375
376    /// Derive shard index that should be used in a solution
377    #[inline]
378    pub fn derive_shard_index(
379        &self,
380        public_key_hash: &Blake3Hash,
381        shard_commitments_root: &ShardCommitmentHash,
382        shard_membership_entropy: &ShardMembershipEntropy,
383        history_size: HistorySize,
384    ) -> ShardIndex {
385        let hash = single_block_keyed_hash(public_key_hash, &{
386            let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
387                + ShardMembershipEntropy::SIZE
388                + HistorySize::SIZE as usize];
389            bytes_to_hash[..ShardCommitmentHash::SIZE]
390                .copy_from_slice(shard_commitments_root.as_bytes());
391            bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
392                .copy_from_slice(shard_membership_entropy.as_bytes());
393            bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
394                .copy_from_slice(history_size.as_bytes());
395            bytes_to_hash
396        })
397        .expect("Input is smaller than block size; qed");
398        // Going through `NanoU256` because the total number of shards is not guaranteed to be a
399        // power of two
400        let shard_index_offset =
401            NanoU256::from_le_bytes(hash) % u64::from(self.leaf_shards().get());
402
403        self.iter_leaf_shards()
404            .nth(shard_index_offset as usize)
405            .unwrap_or(ShardIndex::BEACON_CHAIN)
406    }
407
408    /// Derive shard commitment index that should be used in a solution.
409    ///
410    /// Returned index is always within `0`..[`SolutionShardCommitment::NUM_LEAVES`] range.
411    #[inline]
412    pub fn derive_shard_commitment_index(
413        &self,
414        public_key_hash: &Blake3Hash,
415        shard_commitments_root: &ShardCommitmentHash,
416        shard_membership_entropy: &ShardMembershipEntropy,
417        history_size: HistorySize,
418    ) -> u32 {
419        let hash = single_block_keyed_hash(public_key_hash, &{
420            let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
421                + ShardMembershipEntropy::SIZE
422                + HistorySize::SIZE as usize];
423            bytes_to_hash[..ShardCommitmentHash::SIZE]
424                .copy_from_slice(shard_commitments_root.as_bytes());
425            bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
426                .copy_from_slice(shard_membership_entropy.as_bytes());
427            bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
428                .copy_from_slice(history_size.as_bytes());
429            bytes_to_hash
430        })
431        .expect("Input is smaller than block size; qed");
432        const {
433            assert!(SolutionShardCommitment::NUM_LEAVES.is_power_of_two());
434        }
435        u32::from_le_bytes([hash[0], hash[1], hash[2], hash[3]])
436            % SolutionShardCommitment::NUM_LEAVES as u32
437    }
438
439    /// More efficient version of [`Self::derive_shard_index()`] and
440    /// [`Self::derive_shard_commitment_index()`] in a single call, see those functions for details
441    #[inline]
442    pub fn derive_shard_index_and_shard_commitment_index(
443        &self,
444        public_key_hash: &Blake3Hash,
445        shard_commitments_root: &ShardCommitmentHash,
446        shard_membership_entropy: &ShardMembershipEntropy,
447        history_size: HistorySize,
448    ) -> (ShardIndex, u32) {
449        let hash = single_block_keyed_hash(public_key_hash, &{
450            let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
451                + ShardMembershipEntropy::SIZE
452                + HistorySize::SIZE as usize];
453            bytes_to_hash[..ShardCommitmentHash::SIZE]
454                .copy_from_slice(shard_commitments_root.as_bytes());
455            bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
456                .copy_from_slice(shard_membership_entropy.as_bytes());
457            bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
458                .copy_from_slice(history_size.as_bytes());
459            bytes_to_hash
460        })
461        .expect("Input is smaller than block size; qed");
462
463        // Going through `NanoU256` because the total number of shards is not guaranteed to be a
464        // power of two
465        let shard_index_offset =
466            NanoU256::from_le_bytes(hash) % u64::from(self.leaf_shards().get());
467
468        let shard_index = self
469            .iter_leaf_shards()
470            .nth(shard_index_offset as usize)
471            .unwrap_or(ShardIndex::BEACON_CHAIN);
472
473        const {
474            assert!(SolutionShardCommitment::NUM_LEAVES.is_power_of_two());
475        }
476        let shard_commitment_index = u32::from_le_bytes([hash[0], hash[1], hash[2], hash[3]])
477            % SolutionShardCommitment::NUM_LEAVES as u32;
478
479        (shard_index, shard_commitment_index)
480    }
481}