Skip to main content

ab_core_primitives/
shard.rs

1//! Shard-related primitives
2
3use crate::hashes::Blake3Hash;
4use crate::nano_u256::NanoU256;
5use crate::segments::HistorySize;
6use crate::solutions::{ShardCommitmentHash, ShardMembershipEntropy, SolutionShardCommitment};
7use ab_blake3::single_block_keyed_hash;
8use ab_io_type::trivial_type::TrivialType;
9use core::num::{NonZeroU16, NonZeroU32, NonZeroU128};
10use core::ops::RangeInclusive;
11use derive_more::Display;
12#[cfg(feature = "scale-codec")]
13use parity_scale_codec::{Decode, Encode, Input, MaxEncodedLen};
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Deserializer, Serialize};
16
/// Number of least significant bits of a shard index that identify the intermediate shard
const INTERMEDIATE_SHARD_BITS: u32 = 10;
/// Bit mask that selects the [`INTERMEDIATE_SHARD_BITS`] least significant bits of a shard index
const INTERMEDIATE_SHARD_MASK: u32 = (1 << INTERMEDIATE_SHARD_BITS) - 1;
/// Shard indices that belong to intermediate shards (`0` is taken by the beacon chain)
const INTERMEDIATE_SHARDS_RANGE: RangeInclusive<u32> =
    RangeInclusive::new(1, INTERMEDIATE_SHARD_MASK);
20
/// A kind of shard.
///
/// Schematically, the hierarchy of shards is as follows:
/// ```text
///                          Beacon chain
///                          /          \
///      Intermediate shard 1            Intermediate shard 2
///              /  \                            /  \
/// Leaf shard 11   Leaf shard 12   Leaf shard 21   Leaf shard 22
/// ```
///
/// Beacon chain has index 0, intermediate shards indices 1..=1023. Leaf shards share the same least
/// significant 10 bits as their respective intermediate shards, so leaf shards of intermediate
/// shard 1 have indices like 1025, 2049, 3073, etc.
///
/// If represented as least significant bits first (as it will be in the formatted address), shard
/// index 2049 (the second leaf shard of intermediate shard 1) looks like this:
/// ```text
/// Intermediate shard bits
///     \            /
///      1000_0000_0001_0000_0000
///                 /            \
///                Leaf shard bits
/// ```
///
/// Note that shards that have 10 least significant bits set to 0 (corresponds to the beacon chain)
/// are not leaf shards, in fact, they are not even physical shards that have nodes in general. The
/// meaning of these shards is TBD, currently they are called "phantom" shards and may end up
/// containing some system contracts with special meaning, but no blocks will ever exist for such
/// shards.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ShardKind {
    /// Beacon chain shard
    BeaconChain,
    /// Intermediate shard directly below the beacon chain that has child shards
    IntermediateShard,
    /// Leaf shard, which doesn't have child shards
    LeafShard,
    /// Phantom shard: an index above the intermediate shards range whose 10 least significant bits
    /// are all 0.
    ///
    /// Not a physical shard, no blocks will ever exist for it. The exact meaning is TBD.
    Phantom,
}
61
62impl ShardKind {
63    /// Try to convert to real shard kind.
64    ///
65    /// Returns `None` for phantom shard.
66    #[inline(always)]
67    pub fn to_real(self) -> Option<RealShardKind> {
68        match self {
69            ShardKind::BeaconChain => Some(RealShardKind::BeaconChain),
70            ShardKind::IntermediateShard => Some(RealShardKind::IntermediateShard),
71            ShardKind::LeafShard => Some(RealShardKind::LeafShard),
72            ShardKind::Phantom => None,
73        }
74    }
75}
76
/// Kind of a real shard, meaning a shard for which a block may exist.
///
/// Mirrors [`ShardKind`] without the phantom variant, see [`ShardKind`] for more details.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RealShardKind {
    /// The beacon chain
    BeaconChain,
    /// A shard directly below the beacon chain, which has child shards
    IntermediateShard,
    /// A shard without child shards
    LeafShard,
}
87
88impl From<RealShardKind> for ShardKind {
89    #[inline(always)]
90    fn from(shard_kind: RealShardKind) -> Self {
91        match shard_kind {
92            RealShardKind::BeaconChain => ShardKind::BeaconChain,
93            RealShardKind::IntermediateShard => ShardKind::IntermediateShard,
94            RealShardKind::LeafShard => ShardKind::LeafShard,
95        }
96    }
97}
98
99/// Shard index
100#[derive(Debug, Display, Copy, Clone, Hash, Ord, PartialOrd, Eq, PartialEq, TrivialType)]
101#[cfg_attr(feature = "scale-codec", derive(Encode, Decode, MaxEncodedLen))]
102#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
103#[repr(C)]
104pub struct ShardIndex(u32);
105
106impl const Default for ShardIndex {
107    #[inline(always)]
108    fn default() -> Self {
109        Self::BEACON_CHAIN
110    }
111}
112
113impl const From<ShardIndex> for u32 {
114    #[inline(always)]
115    fn from(shard_index: ShardIndex) -> Self {
116        shard_index.0
117    }
118}
119
120impl ShardIndex {
121    /// Beacon chain
122    pub const BEACON_CHAIN: Self = Self(0);
123    /// Max possible shard index
124    pub const MAX_SHARD_INDEX: u32 = Self::MAX_SHARDS.get() - 1;
125    /// Max possible number of shards
126    pub const MAX_SHARDS: NonZeroU32 = NonZeroU32::new(2u32.pow(20)).expect("Not zero; qed");
127    /// Max possible number of addresses per shard
128    pub const MAX_ADDRESSES_PER_SHARD: NonZeroU128 =
129        NonZeroU128::new(2u128.pow(108)).expect("Not zero; qed");
130
131    /// Create shard index from `u32`.
132    ///
133    /// Returns `None` if `shard_index > ShardIndex::MAX_SHARD_INDEX`
134    ///
135    /// This is typically only necessary for low-level code.
136    #[inline(always)]
137    pub const fn new(shard_index: u32) -> Option<Self> {
138        if shard_index > Self::MAX_SHARD_INDEX {
139            return None;
140        }
141
142        Some(Self(shard_index))
143    }
144
145    /// Whether the shard index corresponds to the beacon chain
146    #[inline(always)]
147    pub const fn is_beacon_chain(&self) -> bool {
148        self.0 == Self::BEACON_CHAIN.0
149    }
150
151    /// Whether the shard index corresponds to an intermediate shard
152    #[inline(always)]
153    pub const fn is_intermediate_shard(&self) -> bool {
154        self.0 >= *INTERMEDIATE_SHARDS_RANGE.start() && self.0 <= *INTERMEDIATE_SHARDS_RANGE.end()
155    }
156
157    /// Whether the shard index corresponds to an intermediate shard
158    #[inline(always)]
159    pub const fn is_leaf_shard(&self) -> bool {
160        if self.0 <= *INTERMEDIATE_SHARDS_RANGE.end() || self.0 > Self::MAX_SHARD_INDEX {
161            return false;
162        }
163
164        self.0 & INTERMEDIATE_SHARD_MASK != 0
165    }
166
167    /// Whether the shard index corresponds to a real shard
168    #[inline(always)]
169    pub const fn is_real(&self) -> bool {
170        !self.is_phantom_shard()
171    }
172
173    /// Whether the shard index corresponds to a phantom shard
174    #[inline(always)]
175    pub const fn is_phantom_shard(&self) -> bool {
176        if self.0 <= *INTERMEDIATE_SHARDS_RANGE.end() || self.0 > Self::MAX_SHARD_INDEX {
177            return false;
178        }
179
180        self.0 & INTERMEDIATE_SHARD_MASK == 0
181    }
182
183    /// Check if this shard is a child shard of `parent`
184    #[inline(always)]
185    pub const fn is_child_of(self, parent: Self) -> bool {
186        match self.shard_kind() {
187            Some(ShardKind::BeaconChain) => false,
188            Some(ShardKind::IntermediateShard | ShardKind::Phantom) => parent.is_beacon_chain(),
189            Some(ShardKind::LeafShard) => {
190                // Check that the least significant bits match
191                self.0 & INTERMEDIATE_SHARD_MASK == parent.0
192            }
193            None => false,
194        }
195    }
196
197    /// Get index of the parent shard (for leaf and intermediate shards)
198    #[inline(always)]
199    pub const fn parent_shard(self) -> Option<ShardIndex> {
200        match self.shard_kind()? {
201            ShardKind::BeaconChain => None,
202            ShardKind::IntermediateShard | ShardKind::Phantom => Some(ShardIndex::BEACON_CHAIN),
203            ShardKind::LeafShard => Some(Self(self.0 & INTERMEDIATE_SHARD_MASK)),
204        }
205    }
206
207    /// Get shard kind
208    #[inline(always)]
209    pub const fn shard_kind(&self) -> Option<ShardKind> {
210        if self.0 == Self::BEACON_CHAIN.0 {
211            Some(ShardKind::BeaconChain)
212        } else if self.0 >= *INTERMEDIATE_SHARDS_RANGE.start()
213            && self.0 <= *INTERMEDIATE_SHARDS_RANGE.end()
214        {
215            Some(ShardKind::IntermediateShard)
216        } else if self.0 > Self::MAX_SHARD_INDEX {
217            None
218        } else if self.0 & INTERMEDIATE_SHARD_MASK == 0 {
219            // Check if the least significant bits correspond to the beacon chain
220            Some(ShardKind::Phantom)
221        } else {
222            Some(ShardKind::LeafShard)
223        }
224    }
225}
226
227/// Unchecked number of shards in the network.
228///
229/// Should be converted into [`NumShards`] before use.
230#[derive(Debug, Copy, Clone, Eq, PartialEq, TrivialType)]
231#[cfg_attr(feature = "scale-codec", derive(Encode, Decode, MaxEncodedLen))]
232#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
233#[repr(C)]
234pub struct NumShardsUnchecked {
235    /// The number of intermediate shards
236    pub intermediate_shards: u16,
237    /// The number of leaf shards per intermediate shard
238    pub leaf_shards_per_intermediate_shard: u16,
239}
240
241impl From<NumShards> for NumShardsUnchecked {
242    fn from(value: NumShards) -> Self {
243        Self {
244            intermediate_shards: value.intermediate_shards.get(),
245            leaf_shards_per_intermediate_shard: value.leaf_shards_per_intermediate_shard.get(),
246        }
247    }
248}
249
/// Number of shards in the network.
///
/// Fields are private, instances are created with [`NumShards::new()`], which checks the values.
/// Only the encoding/serialization side is derived here, decoding/deserialization is implemented
/// by hand on top of [`NumShards::new()`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "scale-codec", derive(Encode, MaxEncodedLen))]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct NumShards {
    /// How many intermediate shards there are
    intermediate_shards: NonZeroU16,
    /// How many leaf shards each intermediate shard has
    leaf_shards_per_intermediate_shard: NonZeroU16,
}
260
/// Decodes both fields in declaration order and then validates them with [`NumShards::new()`]
#[cfg(feature = "scale-codec")]
impl Decode for NumShards {
    fn decode<I>(input: &mut I) -> Result<Self, parity_scale_codec::Error>
    where
        I: Input,
    {
        let intermediate_shards = match NonZeroU16::decode(input) {
            Ok(intermediate_shards) => intermediate_shards,
            Err(error) => {
                return Err(error.chain("Could not decode `NumShards::intermediate_shards`"));
            }
        };
        let leaf_shards_per_intermediate_shard = match NonZeroU16::decode(input) {
            Ok(leaf_shards_per_intermediate_shard) => leaf_shards_per_intermediate_shard,
            Err(error) => {
                return Err(error.chain(
                    "Could not decode `NumShards::leaf_shards_per_intermediate_shard`",
                ));
            }
        };

        // Individually valid fields may still form an invalid combination
        match Self::new(intermediate_shards, leaf_shards_per_intermediate_shard) {
            Some(num_shards) => Ok(num_shards),
            None => Err("Invalid `NumShards`".into()),
        }
    }
}
274
/// Deserializes the fields through a helper struct and then validates them with
/// [`NumShards::new()`]
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for NumShards {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // The helper has the same name and the same fields as the outer type, which derives
        // `Serialize`. Within this function `NumShards` is the helper and `Self` is the outer type.
        #[derive(Deserialize)]
        struct NumShards {
            intermediate_shards: NonZeroU16,
            leaf_shards_per_intermediate_shard: NonZeroU16,
        }

        let NumShards {
            intermediate_shards,
            leaf_shards_per_intermediate_shard,
        } = NumShards::deserialize(deserializer)?;

        match Self::new(intermediate_shards, leaf_shards_per_intermediate_shard) {
            Some(num_shards) => Ok(num_shards),
            None => Err(<D::Error as serde::de::Error>::custom(
                "Invalid `NumShards`",
            )),
        }
    }
}
296
297impl TryFrom<NumShardsUnchecked> for NumShards {
298    type Error = ();
299
300    fn try_from(value: NumShardsUnchecked) -> Result<Self, Self::Error> {
301        Self::new(
302            NonZeroU16::new(value.intermediate_shards).ok_or(())?,
303            NonZeroU16::new(value.leaf_shards_per_intermediate_shard).ok_or(())?,
304        )
305        .ok_or(())
306    }
307}
308
309impl NumShards {
310    /// Create a new instance from a number of intermediate shards and leaf shards per
311    /// intermediate shard.
312    ///
313    /// Returns `None` if inputs are invalid.
314    ///
315    /// This is typically only necessary for low-level code.
316    #[inline(always)]
317    pub const fn new(
318        intermediate_shards: NonZeroU16,
319        leaf_shards_per_intermediate_shard: NonZeroU16,
320    ) -> Option<Self> {
321        if intermediate_shards.get()
322            > (*INTERMEDIATE_SHARDS_RANGE.end() - *INTERMEDIATE_SHARDS_RANGE.start() + 1) as u16
323        {
324            return None;
325        }
326
327        let num_shards = Self {
328            intermediate_shards,
329            leaf_shards_per_intermediate_shard,
330        };
331
332        if num_shards.leaf_shards() > ShardIndex::MAX_SHARDS {
333            return None;
334        }
335
336        Some(num_shards)
337    }
338
339    /// The number of intermediate shards
340    #[inline(always)]
341    pub const fn intermediate_shards(self) -> NonZeroU16 {
342        self.intermediate_shards
343    }
344    /// The number of leaf shards per intermediate shard
345    #[inline(always)]
346    pub const fn leaf_shards_per_intermediate_shard(self) -> NonZeroU16 {
347        self.leaf_shards_per_intermediate_shard
348    }
349
350    /// Total number of leaf shards in the network
351    #[inline(always)]
352    pub const fn leaf_shards(&self) -> NonZeroU32 {
353        NonZeroU32::new(
354            self.intermediate_shards.get() as u32
355                * self.leaf_shards_per_intermediate_shard.get() as u32,
356        )
357        .expect("Not zero; qed")
358    }
359
360    /// Iterator over all intermediate shards
361    #[inline(always)]
362    pub fn iter_intermediate_shards(&self) -> impl Iterator<Item = ShardIndex> {
363        INTERMEDIATE_SHARDS_RANGE
364            .take(usize::from(self.intermediate_shards.get()))
365            .map(ShardIndex)
366    }
367
368    /// Iterator over all intermediate shards
369    #[inline(always)]
370    pub fn iter_leaf_shards(&self) -> impl Iterator<Item = ShardIndex> {
371        self.iter_intermediate_shards()
372            .flat_map(|intermediate_shard| {
373                (0..u32::from(self.leaf_shards_per_intermediate_shard.get())).map(
374                    move |leaf_shard_index| {
375                        ShardIndex(
376                            (leaf_shard_index << INTERMEDIATE_SHARD_BITS) | intermediate_shard.0,
377                        )
378                    },
379                )
380            })
381    }
382
383    /// Derive shard index that should be used in a solution
384    #[inline]
385    pub fn derive_shard_index(
386        &self,
387        public_key_hash: &Blake3Hash,
388        shard_commitments_root: &ShardCommitmentHash,
389        shard_membership_entropy: &ShardMembershipEntropy,
390        history_size: HistorySize,
391    ) -> ShardIndex {
392        let hash = single_block_keyed_hash(public_key_hash, &{
393            let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
394                + ShardMembershipEntropy::SIZE
395                + HistorySize::SIZE as usize];
396            bytes_to_hash[..ShardCommitmentHash::SIZE]
397                .copy_from_slice(shard_commitments_root.as_bytes());
398            bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
399                .copy_from_slice(shard_membership_entropy.as_bytes());
400            bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
401                .copy_from_slice(history_size.as_bytes());
402            bytes_to_hash
403        })
404        .expect("Input is smaller than block size; qed");
405        // Going through `NanoU256` because the total number of shards is not guaranteed to be a
406        // power of two
407        let shard_index_offset =
408            NanoU256::from_le_bytes(hash) % u64::from(self.leaf_shards().get());
409
410        self.iter_leaf_shards()
411            .nth(shard_index_offset as usize)
412            .unwrap_or(ShardIndex::BEACON_CHAIN)
413    }
414
415    /// Derive shard commitment index that should be used in a solution.
416    ///
417    /// Returned index is always within `0`..[`SolutionShardCommitment::NUM_LEAVES`] range.
418    #[inline]
419    pub fn derive_shard_commitment_index(
420        &self,
421        public_key_hash: &Blake3Hash,
422        shard_commitments_root: &ShardCommitmentHash,
423        shard_membership_entropy: &ShardMembershipEntropy,
424        history_size: HistorySize,
425    ) -> u32 {
426        let hash = single_block_keyed_hash(public_key_hash, &{
427            let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
428                + ShardMembershipEntropy::SIZE
429                + HistorySize::SIZE as usize];
430            bytes_to_hash[..ShardCommitmentHash::SIZE]
431                .copy_from_slice(shard_commitments_root.as_bytes());
432            bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
433                .copy_from_slice(shard_membership_entropy.as_bytes());
434            bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
435                .copy_from_slice(history_size.as_bytes());
436            bytes_to_hash
437        })
438        .expect("Input is smaller than block size; qed");
439        const {
440            assert!(SolutionShardCommitment::NUM_LEAVES.is_power_of_two());
441        }
442        u32::from_le_bytes([hash[0], hash[1], hash[2], hash[3]])
443            % SolutionShardCommitment::NUM_LEAVES as u32
444    }
445
446    /// More efficient version of [`Self::derive_shard_index()`] and
447    /// [`Self::derive_shard_commitment_index()`] in a single call, see those functions for details
448    #[inline]
449    pub fn derive_shard_index_and_shard_commitment_index(
450        &self,
451        public_key_hash: &Blake3Hash,
452        shard_commitments_root: &ShardCommitmentHash,
453        shard_membership_entropy: &ShardMembershipEntropy,
454        history_size: HistorySize,
455    ) -> (ShardIndex, u32) {
456        let hash = single_block_keyed_hash(public_key_hash, &{
457            let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
458                + ShardMembershipEntropy::SIZE
459                + HistorySize::SIZE as usize];
460            bytes_to_hash[..ShardCommitmentHash::SIZE]
461                .copy_from_slice(shard_commitments_root.as_bytes());
462            bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
463                .copy_from_slice(shard_membership_entropy.as_bytes());
464            bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
465                .copy_from_slice(history_size.as_bytes());
466            bytes_to_hash
467        })
468        .expect("Input is smaller than block size; qed");
469
470        // Going through `NanoU256` because the total number of shards is not guaranteed to be a
471        // power of two
472        let shard_index_offset =
473            NanoU256::from_le_bytes(hash) % u64::from(self.leaf_shards().get());
474
475        let shard_index = self
476            .iter_leaf_shards()
477            .nth(shard_index_offset as usize)
478            .unwrap_or(ShardIndex::BEACON_CHAIN);
479
480        const {
481            assert!(SolutionShardCommitment::NUM_LEAVES.is_power_of_two());
482        }
483        let shard_commitment_index = u32::from_le_bytes([hash[0], hash[1], hash[2], hash[3]])
484            % SolutionShardCommitment::NUM_LEAVES as u32;
485
486        (shard_index, shard_commitment_index)
487    }
488}