1use crate::hashes::Blake3Hash;
4use crate::nano_u256::NanoU256;
5use crate::segments::HistorySize;
6use crate::solutions::{ShardCommitmentHash, ShardMembershipEntropy, SolutionShardCommitment};
7use ab_blake3::single_block_keyed_hash;
8use ab_io_type::trivial_type::TrivialType;
9use core::num::{NonZeroU16, NonZeroU32, NonZeroU128};
10use core::ops::RangeInclusive;
11use derive_more::Display;
12#[cfg(feature = "scale-codec")]
13use parity_scale_codec::{Decode, Encode, Input, MaxEncodedLen};
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Deserializer, Serialize};
16
/// Shard indices that belong to intermediate shards (the direct children of the beacon chain)
const INTERMEDIATE_SHARDS_RANGE: RangeInclusive<u32> = 1..=1023;
/// Number of low bits of a leaf/phantom shard index that hold the parent intermediate shard index
const INTERMEDIATE_SHARD_BITS: u32 = 10;
/// Mask that selects the low [`INTERMEDIATE_SHARD_BITS`] bits of a shard index
const INTERMEDIATE_SHARD_MASK: u32 = u32::MAX >> (u32::BITS - INTERMEDIATE_SHARD_BITS);
20
/// Kind of shard, as determined from a [`ShardIndex`] by [`ShardIndex::shard_kind()`].
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ShardKind {
    /// The beacon chain (shard index `0`)
    BeaconChain,
    /// Intermediate shard (shard indices in `INTERMEDIATE_SHARDS_RANGE`)
    IntermediateShard,
    /// Leaf shard: index above the intermediate range whose low `INTERMEDIATE_SHARD_BITS` bits
    /// (the parent intermediate shard) are non-zero
    LeafShard,
    /// Phantom shard: index above the intermediate range whose low `INTERMEDIATE_SHARD_BITS` bits
    /// are zero; it has no [`RealShardKind`] counterpart
    Phantom,
}
61
62impl ShardKind {
63 #[inline(always)]
67 pub fn to_real(self) -> Option<RealShardKind> {
68 match self {
69 ShardKind::BeaconChain => Some(RealShardKind::BeaconChain),
70 ShardKind::IntermediateShard => Some(RealShardKind::IntermediateShard),
71 ShardKind::LeafShard => Some(RealShardKind::LeafShard),
72 ShardKind::Phantom => None,
73 }
74 }
75}
76
/// Kind of a real (non-phantom) shard; the subset of [`ShardKind`] without `Phantom`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum RealShardKind {
    /// The beacon chain
    BeaconChain,
    /// Intermediate shard
    IntermediateShard,
    /// Leaf shard
    LeafShard,
}
87
88impl From<RealShardKind> for ShardKind {
89 #[inline(always)]
90 fn from(shard_kind: RealShardKind) -> Self {
91 match shard_kind {
92 RealShardKind::BeaconChain => ShardKind::BeaconChain,
93 RealShardKind::IntermediateShard => ShardKind::IntermediateShard,
94 RealShardKind::LeafShard => ShardKind::LeafShard,
95 }
96 }
97}
98
/// Index of a shard.
///
/// The index space has [`ShardIndex::MAX_SHARDS`] (`2^20`) entries:
/// * `0` is the beacon chain
/// * `1..=1023` are intermediate shards
/// * for larger indices the low `INTERMEDIATE_SHARD_BITS` bits hold the parent intermediate shard
///   and the remaining bits the position of the leaf shard within it; indices whose low bits are
///   all zero are phantom shards
///
/// NOTE(review): the derived `Decode`/`Deserialize` implementations read the inner `u32`
/// directly, so decoded values do not go through the range check in [`ShardIndex::new()`].
#[derive(Debug, Display, Copy, Clone, Hash, Ord, PartialOrd, Eq, PartialEq, TrivialType)]
#[cfg_attr(feature = "scale-codec", derive(Encode, Decode, MaxEncodedLen))]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[repr(C)]
pub struct ShardIndex(u32);
105
106impl const Default for ShardIndex {
107 #[inline(always)]
108 fn default() -> Self {
109 Self::BEACON_CHAIN
110 }
111}
112
113impl const From<ShardIndex> for u32 {
114 #[inline(always)]
115 fn from(shard_index: ShardIndex) -> Self {
116 shard_index.0
117 }
118}
119
120impl ShardIndex {
121 pub const BEACON_CHAIN: Self = Self(0);
123 pub const MAX_SHARD_INDEX: u32 = Self::MAX_SHARDS.get() - 1;
125 pub const MAX_SHARDS: NonZeroU32 = NonZeroU32::new(2u32.pow(20)).expect("Not zero; qed");
127 pub const MAX_ADDRESSES_PER_SHARD: NonZeroU128 =
129 NonZeroU128::new(2u128.pow(108)).expect("Not zero; qed");
130
131 #[inline(always)]
137 pub const fn new(shard_index: u32) -> Option<Self> {
138 if shard_index > Self::MAX_SHARD_INDEX {
139 return None;
140 }
141
142 Some(Self(shard_index))
143 }
144
145 #[inline(always)]
147 pub const fn is_beacon_chain(&self) -> bool {
148 self.0 == Self::BEACON_CHAIN.0
149 }
150
151 #[inline(always)]
153 pub const fn is_intermediate_shard(&self) -> bool {
154 self.0 >= *INTERMEDIATE_SHARDS_RANGE.start() && self.0 <= *INTERMEDIATE_SHARDS_RANGE.end()
155 }
156
157 #[inline(always)]
159 pub const fn is_leaf_shard(&self) -> bool {
160 if self.0 <= *INTERMEDIATE_SHARDS_RANGE.end() || self.0 > Self::MAX_SHARD_INDEX {
161 return false;
162 }
163
164 self.0 & INTERMEDIATE_SHARD_MASK != 0
165 }
166
167 #[inline(always)]
169 pub const fn is_real(&self) -> bool {
170 !self.is_phantom_shard()
171 }
172
173 #[inline(always)]
175 pub const fn is_phantom_shard(&self) -> bool {
176 if self.0 <= *INTERMEDIATE_SHARDS_RANGE.end() || self.0 > Self::MAX_SHARD_INDEX {
177 return false;
178 }
179
180 self.0 & INTERMEDIATE_SHARD_MASK == 0
181 }
182
183 #[inline(always)]
185 pub const fn is_child_of(self, parent: Self) -> bool {
186 match self.shard_kind() {
187 Some(ShardKind::BeaconChain) => false,
188 Some(ShardKind::IntermediateShard | ShardKind::Phantom) => parent.is_beacon_chain(),
189 Some(ShardKind::LeafShard) => {
190 self.0 & INTERMEDIATE_SHARD_MASK == parent.0
192 }
193 None => false,
194 }
195 }
196
197 #[inline(always)]
199 pub const fn parent_shard(self) -> Option<ShardIndex> {
200 match self.shard_kind()? {
201 ShardKind::BeaconChain => None,
202 ShardKind::IntermediateShard | ShardKind::Phantom => Some(ShardIndex::BEACON_CHAIN),
203 ShardKind::LeafShard => Some(Self(self.0 & INTERMEDIATE_SHARD_MASK)),
204 }
205 }
206
207 #[inline(always)]
209 pub const fn shard_kind(&self) -> Option<ShardKind> {
210 if self.0 == Self::BEACON_CHAIN.0 {
211 Some(ShardKind::BeaconChain)
212 } else if self.0 >= *INTERMEDIATE_SHARDS_RANGE.start()
213 && self.0 <= *INTERMEDIATE_SHARDS_RANGE.end()
214 {
215 Some(ShardKind::IntermediateShard)
216 } else if self.0 > Self::MAX_SHARD_INDEX {
217 None
218 } else if self.0 & INTERMEDIATE_SHARD_MASK == 0 {
219 Some(ShardKind::Phantom)
221 } else {
222 Some(ShardKind::LeafShard)
223 }
224 }
225}
226
/// Unvalidated counterpart of [`NumShards`].
///
/// Convert with `NumShards::try_from()`, which applies the checks of [`NumShards::new()`].
#[derive(Debug, Copy, Clone, Eq, PartialEq, TrivialType)]
#[cfg_attr(feature = "scale-codec", derive(Encode, Decode, MaxEncodedLen))]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[repr(C)]
pub struct NumShardsUnchecked {
    /// Number of intermediate shards (may be zero or out of range)
    pub intermediate_shards: u16,
    /// Number of leaf shards per intermediate shard (may be zero or out of range)
    pub leaf_shards_per_intermediate_shard: u16,
}
240
241impl From<NumShards> for NumShardsUnchecked {
242 fn from(value: NumShards) -> Self {
243 Self {
244 intermediate_shards: value.intermediate_shards.get(),
245 leaf_shards_per_intermediate_shard: value.leaf_shards_per_intermediate_shard.get(),
246 }
247 }
248}
249
/// Validated number of intermediate shards and leaf shards per intermediate shard.
///
/// Instances are only created through [`NumShards::new()`]; the `Decode`, `Deserialize` and
/// `TryFrom<NumShardsUnchecked>` implementations all route through it.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "scale-codec", derive(Encode, MaxEncodedLen))]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct NumShards {
    /// Number of intermediate shards
    intermediate_shards: NonZeroU16,
    /// Number of leaf shards under each intermediate shard
    leaf_shards_per_intermediate_shard: NonZeroU16,
}
260
#[cfg(feature = "scale-codec")]
impl Decode for NumShards {
    /// Decodes both counts in field order, then validates them with [`NumShards::new()`].
    fn decode<I: Input>(input: &mut I) -> Result<Self, parity_scale_codec::Error> {
        let intermediate_shards = NonZeroU16::decode(input)
            .map_err(|error| error.chain("Could not decode `NumShards::intermediate_shards`"))?;
        let leaf_shards_per_intermediate_shard = NonZeroU16::decode(input).map_err(|error| {
            error.chain("Could not decode `NumShards::leaf_shards_per_intermediate_shard`")
        })?;

        match Self::new(intermediate_shards, leaf_shards_per_intermediate_shard) {
            Some(num_shards) => Ok(num_shards),
            None => Err("Invalid `NumShards`".into()),
        }
    }
}
274
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for NumShards {
    /// Deserializes the raw fields through a mirror struct, then validates them with
    /// [`NumShards::new()`].
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Mirror of the outer type without validation; its name and field names are what serde
        // sees, so they must stay as they are.
        #[derive(Deserialize)]
        struct NumShards {
            intermediate_shards: NonZeroU16,
            leaf_shards_per_intermediate_shard: NonZeroU16,
        }

        let NumShards {
            intermediate_shards,
            leaf_shards_per_intermediate_shard,
        } = NumShards::deserialize(deserializer)?;

        match Self::new(intermediate_shards, leaf_shards_per_intermediate_shard) {
            Some(num_shards) => Ok(num_shards),
            None => Err(serde::de::Error::custom("Invalid `NumShards`")),
        }
    }
}
296
297impl TryFrom<NumShardsUnchecked> for NumShards {
298 type Error = ();
299
300 fn try_from(value: NumShardsUnchecked) -> Result<Self, Self::Error> {
301 Self::new(
302 NonZeroU16::new(value.intermediate_shards).ok_or(())?,
303 NonZeroU16::new(value.leaf_shards_per_intermediate_shard).ok_or(())?,
304 )
305 .ok_or(())
306 }
307}
308
309impl NumShards {
310 #[inline(always)]
317 pub const fn new(
318 intermediate_shards: NonZeroU16,
319 leaf_shards_per_intermediate_shard: NonZeroU16,
320 ) -> Option<Self> {
321 if intermediate_shards.get()
322 > (*INTERMEDIATE_SHARDS_RANGE.end() - *INTERMEDIATE_SHARDS_RANGE.start() + 1) as u16
323 {
324 return None;
325 }
326
327 let num_shards = Self {
328 intermediate_shards,
329 leaf_shards_per_intermediate_shard,
330 };
331
332 if num_shards.leaf_shards() > ShardIndex::MAX_SHARDS {
333 return None;
334 }
335
336 Some(num_shards)
337 }
338
339 #[inline(always)]
341 pub const fn intermediate_shards(self) -> NonZeroU16 {
342 self.intermediate_shards
343 }
344 #[inline(always)]
346 pub const fn leaf_shards_per_intermediate_shard(self) -> NonZeroU16 {
347 self.leaf_shards_per_intermediate_shard
348 }
349
350 #[inline(always)]
352 pub const fn leaf_shards(&self) -> NonZeroU32 {
353 NonZeroU32::new(
354 self.intermediate_shards.get() as u32
355 * self.leaf_shards_per_intermediate_shard.get() as u32,
356 )
357 .expect("Not zero; qed")
358 }
359
360 #[inline(always)]
362 pub fn iter_intermediate_shards(&self) -> impl Iterator<Item = ShardIndex> {
363 INTERMEDIATE_SHARDS_RANGE
364 .take(usize::from(self.intermediate_shards.get()))
365 .map(ShardIndex)
366 }
367
368 #[inline(always)]
370 pub fn iter_leaf_shards(&self) -> impl Iterator<Item = ShardIndex> {
371 self.iter_intermediate_shards()
372 .flat_map(|intermediate_shard| {
373 (0..u32::from(self.leaf_shards_per_intermediate_shard.get())).map(
374 move |leaf_shard_index| {
375 ShardIndex(
376 (leaf_shard_index << INTERMEDIATE_SHARD_BITS) | intermediate_shard.0,
377 )
378 },
379 )
380 })
381 }
382
383 #[inline]
385 pub fn derive_shard_index(
386 &self,
387 public_key_hash: &Blake3Hash,
388 shard_commitments_root: &ShardCommitmentHash,
389 shard_membership_entropy: &ShardMembershipEntropy,
390 history_size: HistorySize,
391 ) -> ShardIndex {
392 let hash = single_block_keyed_hash(public_key_hash, &{
393 let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
394 + ShardMembershipEntropy::SIZE
395 + HistorySize::SIZE as usize];
396 bytes_to_hash[..ShardCommitmentHash::SIZE]
397 .copy_from_slice(shard_commitments_root.as_bytes());
398 bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
399 .copy_from_slice(shard_membership_entropy.as_bytes());
400 bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
401 .copy_from_slice(history_size.as_bytes());
402 bytes_to_hash
403 })
404 .expect("Input is smaller than block size; qed");
405 let shard_index_offset =
408 NanoU256::from_le_bytes(hash) % u64::from(self.leaf_shards().get());
409
410 self.iter_leaf_shards()
411 .nth(shard_index_offset as usize)
412 .unwrap_or(ShardIndex::BEACON_CHAIN)
413 }
414
415 #[inline]
419 pub fn derive_shard_commitment_index(
420 &self,
421 public_key_hash: &Blake3Hash,
422 shard_commitments_root: &ShardCommitmentHash,
423 shard_membership_entropy: &ShardMembershipEntropy,
424 history_size: HistorySize,
425 ) -> u32 {
426 let hash = single_block_keyed_hash(public_key_hash, &{
427 let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
428 + ShardMembershipEntropy::SIZE
429 + HistorySize::SIZE as usize];
430 bytes_to_hash[..ShardCommitmentHash::SIZE]
431 .copy_from_slice(shard_commitments_root.as_bytes());
432 bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
433 .copy_from_slice(shard_membership_entropy.as_bytes());
434 bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
435 .copy_from_slice(history_size.as_bytes());
436 bytes_to_hash
437 })
438 .expect("Input is smaller than block size; qed");
439 const {
440 assert!(SolutionShardCommitment::NUM_LEAVES.is_power_of_two());
441 }
442 u32::from_le_bytes([hash[0], hash[1], hash[2], hash[3]])
443 % SolutionShardCommitment::NUM_LEAVES as u32
444 }
445
446 #[inline]
449 pub fn derive_shard_index_and_shard_commitment_index(
450 &self,
451 public_key_hash: &Blake3Hash,
452 shard_commitments_root: &ShardCommitmentHash,
453 shard_membership_entropy: &ShardMembershipEntropy,
454 history_size: HistorySize,
455 ) -> (ShardIndex, u32) {
456 let hash = single_block_keyed_hash(public_key_hash, &{
457 let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
458 + ShardMembershipEntropy::SIZE
459 + HistorySize::SIZE as usize];
460 bytes_to_hash[..ShardCommitmentHash::SIZE]
461 .copy_from_slice(shard_commitments_root.as_bytes());
462 bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
463 .copy_from_slice(shard_membership_entropy.as_bytes());
464 bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
465 .copy_from_slice(history_size.as_bytes());
466 bytes_to_hash
467 })
468 .expect("Input is smaller than block size; qed");
469
470 let shard_index_offset =
473 NanoU256::from_le_bytes(hash) % u64::from(self.leaf_shards().get());
474
475 let shard_index = self
476 .iter_leaf_shards()
477 .nth(shard_index_offset as usize)
478 .unwrap_or(ShardIndex::BEACON_CHAIN);
479
480 const {
481 assert!(SolutionShardCommitment::NUM_LEAVES.is_power_of_two());
482 }
483 let shard_commitment_index = u32::from_le_bytes([hash[0], hash[1], hash[2], hash[3]])
484 % SolutionShardCommitment::NUM_LEAVES as u32;
485
486 (shard_index, shard_commitment_index)
487 }
488}