1use crate::hashes::Blake3Hash;
4use crate::nano_u256::NanoU256;
5use crate::segments::HistorySize;
6use crate::solutions::{ShardCommitmentHash, ShardMembershipEntropy, SolutionShardCommitment};
7use ab_blake3::single_block_keyed_hash;
8use ab_io_type::trivial_type::TrivialType;
9use core::num::{NonZeroU16, NonZeroU32, NonZeroU128};
10use core::ops::RangeInclusive;
11use derive_more::Display;
12#[cfg(feature = "scale-codec")]
13use parity_scale_codec::{Decode, Encode, Input, MaxEncodedLen};
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Deserializer, Serialize};
16
/// Number of low bits of a leaf shard index that hold the index of its parent intermediate shard
const INTERMEDIATE_SHARD_BITS: u32 = 10;
/// Mask selecting the low [`INTERMEDIATE_SHARD_BITS`] bits of a shard index (`0x3ff`)
const INTERMEDIATE_SHARD_MASK: u32 = (1 << INTERMEDIATE_SHARD_BITS) - 1;
/// Shard indices that denote intermediate shards: every non-zero value that fits in the mask
const INTERMEDIATE_SHARDS_RANGE: RangeInclusive<u32> = 1..=INTERMEDIATE_SHARD_MASK;
20
/// Kind of a shard, as classified by [`ShardIndex::shard_kind()`]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ShardKind {
    /// The beacon chain: root of the shard hierarchy, it has no parent
    BeaconChain,
    /// Intermediate shard, a direct child of the beacon chain
    IntermediateShard,
    /// Leaf shard, a child of the intermediate shard named by its low index bits
    LeafShard,
    /// Index in the leaf shard range whose low (parent) bits are all zero; it is not a real shard
    Phantom,
}
61
62impl ShardKind {
63 #[inline(always)]
67 pub fn to_real(self) -> Option<RealShardKind> {
68 match self {
69 ShardKind::BeaconChain => Some(RealShardKind::BeaconChain),
70 ShardKind::IntermediateShard => Some(RealShardKind::IntermediateShard),
71 ShardKind::LeafShard => Some(RealShardKind::LeafShard),
72 ShardKind::Phantom => None,
73 }
74 }
75}
76
/// Kind of a real shard: [`ShardKind`] without the `Phantom` variant
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum RealShardKind {
    /// The beacon chain
    BeaconChain,
    /// Intermediate shard
    IntermediateShard,
    /// Leaf shard
    LeafShard,
}
87
88impl From<RealShardKind> for ShardKind {
89 #[inline(always)]
90 fn from(shard_kind: RealShardKind) -> Self {
91 match shard_kind {
92 RealShardKind::BeaconChain => ShardKind::BeaconChain,
93 RealShardKind::IntermediateShard => ShardKind::IntermediateShard,
94 RealShardKind::LeafShard => ShardKind::LeafShard,
95 }
96 }
97}
98
99#[derive(Debug, Display, Copy, Clone, Hash, Ord, PartialOrd, Eq, PartialEq, TrivialType)]
101#[cfg_attr(feature = "scale-codec", derive(Encode, Decode, MaxEncodedLen))]
102#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
103#[repr(C)]
104pub struct ShardIndex(u32);
105
106impl const From<ShardIndex> for u32 {
107 #[inline(always)]
108 fn from(shard_index: ShardIndex) -> Self {
109 shard_index.0
110 }
111}
112
113impl ShardIndex {
114 pub const BEACON_CHAIN: Self = Self(0);
116 pub const MAX_SHARD_INDEX: u32 = Self::MAX_SHARDS.get() - 1;
118 pub const MAX_SHARDS: NonZeroU32 = NonZeroU32::new(2u32.pow(20)).expect("Not zero; qed");
120 pub const MAX_ADDRESSES_PER_SHARD: NonZeroU128 =
122 NonZeroU128::new(2u128.pow(108)).expect("Not zero; qed");
123
124 #[inline(always)]
130 pub const fn new(shard_index: u32) -> Option<Self> {
131 if shard_index > Self::MAX_SHARD_INDEX {
132 return None;
133 }
134
135 Some(Self(shard_index))
136 }
137
138 #[inline(always)]
140 pub const fn is_beacon_chain(&self) -> bool {
141 self.0 == Self::BEACON_CHAIN.0
142 }
143
144 #[inline(always)]
146 pub const fn is_intermediate_shard(&self) -> bool {
147 self.0 >= *INTERMEDIATE_SHARDS_RANGE.start() && self.0 <= *INTERMEDIATE_SHARDS_RANGE.end()
148 }
149
150 #[inline(always)]
152 pub const fn is_leaf_shard(&self) -> bool {
153 if self.0 <= *INTERMEDIATE_SHARDS_RANGE.end() || self.0 > Self::MAX_SHARD_INDEX {
154 return false;
155 }
156
157 self.0 & INTERMEDIATE_SHARD_MASK != 0
158 }
159
160 #[inline(always)]
162 pub const fn is_real(&self) -> bool {
163 !self.is_phantom_shard()
164 }
165
166 #[inline(always)]
168 pub const fn is_phantom_shard(&self) -> bool {
169 if self.0 <= *INTERMEDIATE_SHARDS_RANGE.end() || self.0 > Self::MAX_SHARD_INDEX {
170 return false;
171 }
172
173 self.0 & INTERMEDIATE_SHARD_MASK == 0
174 }
175
176 #[inline(always)]
178 pub const fn is_child_of(self, parent: Self) -> bool {
179 match self.shard_kind() {
180 Some(ShardKind::BeaconChain) => false,
181 Some(ShardKind::IntermediateShard | ShardKind::Phantom) => parent.is_beacon_chain(),
182 Some(ShardKind::LeafShard) => {
183 self.0 & INTERMEDIATE_SHARD_MASK == parent.0
185 }
186 None => false,
187 }
188 }
189
190 #[inline(always)]
192 pub const fn parent_shard(self) -> Option<ShardIndex> {
193 match self.shard_kind()? {
194 ShardKind::BeaconChain => None,
195 ShardKind::IntermediateShard | ShardKind::Phantom => Some(ShardIndex::BEACON_CHAIN),
196 ShardKind::LeafShard => Some(Self(self.0 & INTERMEDIATE_SHARD_MASK)),
197 }
198 }
199
200 #[inline(always)]
202 pub const fn shard_kind(&self) -> Option<ShardKind> {
203 if self.0 == Self::BEACON_CHAIN.0 {
204 Some(ShardKind::BeaconChain)
205 } else if self.0 >= *INTERMEDIATE_SHARDS_RANGE.start()
206 && self.0 <= *INTERMEDIATE_SHARDS_RANGE.end()
207 {
208 Some(ShardKind::IntermediateShard)
209 } else if self.0 > Self::MAX_SHARD_INDEX {
210 None
211 } else if self.0 & INTERMEDIATE_SHARD_MASK == 0 {
212 Some(ShardKind::Phantom)
214 } else {
215 Some(ShardKind::LeafShard)
216 }
217 }
218}
219
220#[derive(Debug, Copy, Clone, Eq, PartialEq, TrivialType)]
224#[cfg_attr(feature = "scale-codec", derive(Encode, Decode, MaxEncodedLen))]
225#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
226#[repr(C)]
227pub struct NumShardsUnchecked {
228 pub intermediate_shards: u16,
230 pub leaf_shards_per_intermediate_shard: u16,
232}
233
234impl From<NumShards> for NumShardsUnchecked {
235 fn from(value: NumShards) -> Self {
236 Self {
237 intermediate_shards: value.intermediate_shards.get(),
238 leaf_shards_per_intermediate_shard: value.leaf_shards_per_intermediate_shard.get(),
239 }
240 }
241}
242
243#[derive(Debug, Copy, Clone, Eq, PartialEq)]
245#[cfg_attr(feature = "scale-codec", derive(Encode, MaxEncodedLen))]
246#[cfg_attr(feature = "serde", derive(Serialize))]
247pub struct NumShards {
248 intermediate_shards: NonZeroU16,
250 leaf_shards_per_intermediate_shard: NonZeroU16,
252}
253
#[cfg(feature = "scale-codec")]
impl Decode for NumShards {
    /// Decodes both counts and validates them with `NumShards::new()`
    fn decode<I: Input>(input: &mut I) -> Result<Self, parity_scale_codec::Error> {
        // Fields are read in declaration order, matching the derived `Encode`
        let intermediate_shards: NonZeroU16 = Decode::decode(input)
            .map_err(|error| error.chain("Could not decode `NumShards::intermediate_shards`"))?;
        let leaf_shards_per_intermediate_shard: NonZeroU16 =
            Decode::decode(input).map_err(|error| {
                error.chain("Could not decode `NumShards::leaf_shards_per_intermediate_shard`")
            })?;

        match Self::new(intermediate_shards, leaf_shards_per_intermediate_shard) {
            Some(num_shards) => Ok(num_shards),
            None => Err("Invalid `NumShards`".into()),
        }
    }
}
267
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for NumShards {
    /// Deserializes both counts and validates them with `NumShards::new()`
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Unvalidated wire shape. It keeps the name `NumShards` because the serde derive passes
        // the struct name to the deserializer.
        #[derive(Deserialize)]
        struct NumShards {
            intermediate_shards: NonZeroU16,
            leaf_shards_per_intermediate_shard: NonZeroU16,
        }

        let NumShards {
            intermediate_shards,
            leaf_shards_per_intermediate_shard,
        } = NumShards::deserialize(deserializer)?;

        match Self::new(intermediate_shards, leaf_shards_per_intermediate_shard) {
            Some(num_shards) => Ok(num_shards),
            None => Err(serde::de::Error::custom("Invalid `NumShards`")),
        }
    }
}
289
290impl TryFrom<NumShardsUnchecked> for NumShards {
291 type Error = ();
292
293 fn try_from(value: NumShardsUnchecked) -> Result<Self, Self::Error> {
294 Self::new(
295 NonZeroU16::new(value.intermediate_shards).ok_or(())?,
296 NonZeroU16::new(value.leaf_shards_per_intermediate_shard).ok_or(())?,
297 )
298 .ok_or(())
299 }
300}
301
302impl NumShards {
303 #[inline(always)]
310 pub const fn new(
311 intermediate_shards: NonZeroU16,
312 leaf_shards_per_intermediate_shard: NonZeroU16,
313 ) -> Option<Self> {
314 if intermediate_shards.get()
315 > (*INTERMEDIATE_SHARDS_RANGE.end() - *INTERMEDIATE_SHARDS_RANGE.start() + 1) as u16
316 {
317 return None;
318 }
319
320 let num_shards = Self {
321 intermediate_shards,
322 leaf_shards_per_intermediate_shard,
323 };
324
325 if num_shards.leaf_shards() > ShardIndex::MAX_SHARDS {
326 return None;
327 }
328
329 Some(num_shards)
330 }
331
332 #[inline(always)]
334 pub const fn intermediate_shards(self) -> NonZeroU16 {
335 self.intermediate_shards
336 }
337 #[inline(always)]
339 pub const fn leaf_shards_per_intermediate_shard(self) -> NonZeroU16 {
340 self.leaf_shards_per_intermediate_shard
341 }
342
343 #[inline(always)]
345 pub const fn leaf_shards(&self) -> NonZeroU32 {
346 NonZeroU32::new(
347 self.intermediate_shards.get() as u32
348 * self.leaf_shards_per_intermediate_shard.get() as u32,
349 )
350 .expect("Not zero; qed")
351 }
352
353 #[inline(always)]
355 pub fn iter_intermediate_shards(&self) -> impl Iterator<Item = ShardIndex> {
356 INTERMEDIATE_SHARDS_RANGE
357 .take(usize::from(self.intermediate_shards.get()))
358 .map(ShardIndex)
359 }
360
361 #[inline(always)]
363 pub fn iter_leaf_shards(&self) -> impl Iterator<Item = ShardIndex> {
364 self.iter_intermediate_shards()
365 .flat_map(|intermediate_shard| {
366 (0..u32::from(self.leaf_shards_per_intermediate_shard.get())).map(
367 move |leaf_shard_index| {
368 ShardIndex(
369 (leaf_shard_index << INTERMEDIATE_SHARD_BITS) | intermediate_shard.0,
370 )
371 },
372 )
373 })
374 }
375
376 #[inline]
378 pub fn derive_shard_index(
379 &self,
380 public_key_hash: &Blake3Hash,
381 shard_commitments_root: &ShardCommitmentHash,
382 shard_membership_entropy: &ShardMembershipEntropy,
383 history_size: HistorySize,
384 ) -> ShardIndex {
385 let hash = single_block_keyed_hash(public_key_hash, &{
386 let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
387 + ShardMembershipEntropy::SIZE
388 + HistorySize::SIZE as usize];
389 bytes_to_hash[..ShardCommitmentHash::SIZE]
390 .copy_from_slice(shard_commitments_root.as_bytes());
391 bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
392 .copy_from_slice(shard_membership_entropy.as_bytes());
393 bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
394 .copy_from_slice(history_size.as_bytes());
395 bytes_to_hash
396 })
397 .expect("Input is smaller than block size; qed");
398 let shard_index_offset =
401 NanoU256::from_le_bytes(hash) % u64::from(self.leaf_shards().get());
402
403 self.iter_leaf_shards()
404 .nth(shard_index_offset as usize)
405 .unwrap_or(ShardIndex::BEACON_CHAIN)
406 }
407
408 #[inline]
412 pub fn derive_shard_commitment_index(
413 &self,
414 public_key_hash: &Blake3Hash,
415 shard_commitments_root: &ShardCommitmentHash,
416 shard_membership_entropy: &ShardMembershipEntropy,
417 history_size: HistorySize,
418 ) -> u32 {
419 let hash = single_block_keyed_hash(public_key_hash, &{
420 let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
421 + ShardMembershipEntropy::SIZE
422 + HistorySize::SIZE as usize];
423 bytes_to_hash[..ShardCommitmentHash::SIZE]
424 .copy_from_slice(shard_commitments_root.as_bytes());
425 bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
426 .copy_from_slice(shard_membership_entropy.as_bytes());
427 bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
428 .copy_from_slice(history_size.as_bytes());
429 bytes_to_hash
430 })
431 .expect("Input is smaller than block size; qed");
432 const {
433 assert!(SolutionShardCommitment::NUM_LEAVES.is_power_of_two());
434 }
435 u32::from_le_bytes([hash[0], hash[1], hash[2], hash[3]])
436 % SolutionShardCommitment::NUM_LEAVES as u32
437 }
438
439 #[inline]
442 pub fn derive_shard_index_and_shard_commitment_index(
443 &self,
444 public_key_hash: &Blake3Hash,
445 shard_commitments_root: &ShardCommitmentHash,
446 shard_membership_entropy: &ShardMembershipEntropy,
447 history_size: HistorySize,
448 ) -> (ShardIndex, u32) {
449 let hash = single_block_keyed_hash(public_key_hash, &{
450 let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
451 + ShardMembershipEntropy::SIZE
452 + HistorySize::SIZE as usize];
453 bytes_to_hash[..ShardCommitmentHash::SIZE]
454 .copy_from_slice(shard_commitments_root.as_bytes());
455 bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
456 .copy_from_slice(shard_membership_entropy.as_bytes());
457 bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
458 .copy_from_slice(history_size.as_bytes());
459 bytes_to_hash
460 })
461 .expect("Input is smaller than block size; qed");
462
463 let shard_index_offset =
466 NanoU256::from_le_bytes(hash) % u64::from(self.leaf_shards().get());
467
468 let shard_index = self
469 .iter_leaf_shards()
470 .nth(shard_index_offset as usize)
471 .unwrap_or(ShardIndex::BEACON_CHAIN);
472
473 const {
474 assert!(SolutionShardCommitment::NUM_LEAVES.is_power_of_two());
475 }
476 let shard_commitment_index = u32::from_le_bytes([hash[0], hash[1], hash[2], hash[3]])
477 % SolutionShardCommitment::NUM_LEAVES as u32;
478
479 (shard_index, shard_commitment_index)
480 }
481}