1use crate::hashes::Blake3Hash;
4use crate::nano_u256::NanoU256;
5use crate::segments::HistorySize;
6use crate::solutions::{ShardCommitmentHash, ShardMembershipEntropy, SolutionShardCommitment};
7use ab_blake3::single_block_keyed_hash;
8use ab_io_type::trivial_type::TrivialType;
9use core::num::{NonZeroU16, NonZeroU32, NonZeroU128};
10use core::ops::RangeInclusive;
11use derive_more::Display;
12#[cfg(feature = "scale-codec")]
13use parity_scale_codec::{Decode, Encode, Input, MaxEncodedLen};
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Deserializer, Serialize};
16
/// Number of low-order bits of a shard index that hold the intermediate shard index
const INTERMEDIATE_SHARD_BITS: u32 = 10;
/// Mask selecting the intermediate shard bits of a shard index (`0x3ff`)
const INTERMEDIATE_SHARD_MASK: u32 = (1 << INTERMEDIATE_SHARD_BITS) - 1;
/// Shard indices used by intermediate shards (`1..=1023`)
const INTERMEDIATE_SHARDS_RANGE: RangeInclusive<u32> = 1..=INTERMEDIATE_SHARD_MASK;
20
/// Kind of shard a [`ShardIndex`] refers to, including phantom shards
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ShardKind {
    /// The beacon chain (shard index `0`)
    BeaconChain,
    /// Intermediate shard (shard indices `1..=1023`)
    IntermediateShard,
    /// Leaf shard (index above the intermediate range with non-zero intermediate shard bits)
    LeafShard,
    /// Phantom shard (index above the intermediate range with all intermediate shard bits
    /// zero); it does not correspond to a real shard
    Phantom,
}

impl ShardKind {
    /// Convert to [`RealShardKind`], returns `None` for [`ShardKind::Phantom`]
    #[inline(always)]
    pub fn to_real(self) -> Option<RealShardKind> {
        let real = match self {
            Self::BeaconChain => RealShardKind::BeaconChain,
            Self::IntermediateShard => RealShardKind::IntermediateShard,
            Self::LeafShard => RealShardKind::LeafShard,
            Self::Phantom => return None,
        };

        Some(real)
    }
}

/// Kind of shard that actually exists (no phantom variant)
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum RealShardKind {
    /// The beacon chain
    BeaconChain,
    /// Intermediate shard
    IntermediateShard,
    /// Leaf shard
    LeafShard,
}

impl From<RealShardKind> for ShardKind {
    /// Lossless widening of a real shard kind into [`ShardKind`]
    #[inline(always)]
    fn from(shard_kind: RealShardKind) -> Self {
        use RealShardKind as Real;

        match shard_kind {
            Real::BeaconChain => Self::BeaconChain,
            Real::IntermediateShard => Self::IntermediateShard,
            Real::LeafShard => Self::LeafShard,
        }
    }
}
98
99#[derive(Debug, Display, Copy, Clone, Hash, Ord, PartialOrd, Eq, PartialEq, TrivialType)]
101#[cfg_attr(feature = "scale-codec", derive(Encode, Decode, MaxEncodedLen))]
102#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
103#[repr(C)]
104pub struct ShardIndex(u32);
105
106impl ShardIndex {
107 pub const BEACON_CHAIN: Self = Self(0);
109 pub const MAX_SHARD_INDEX: u32 = Self::MAX_SHARDS.get() - 1;
111 pub const MAX_SHARDS: NonZeroU32 = NonZeroU32::new(2u32.pow(20)).expect("Not zero; qed");
113 pub const MAX_ADDRESSES_PER_SHARD: NonZeroU128 =
115 NonZeroU128::new(2u128.pow(108)).expect("Not zero; qed");
116
117 #[inline(always)]
123 pub const fn new(shard_index: u32) -> Option<Self> {
124 if shard_index > Self::MAX_SHARD_INDEX {
125 return None;
126 }
127
128 Some(Self(shard_index))
129 }
130
131 #[inline(always)]
136 pub const fn as_u32(self) -> u32 {
137 self.0
138 }
139
140 #[inline(always)]
142 pub const fn is_beacon_chain(&self) -> bool {
143 self.0 == Self::BEACON_CHAIN.0
144 }
145
146 #[inline(always)]
148 pub const fn is_intermediate_shard(&self) -> bool {
149 self.0 >= *INTERMEDIATE_SHARDS_RANGE.start() && self.0 <= *INTERMEDIATE_SHARDS_RANGE.end()
150 }
151
152 #[inline(always)]
154 pub const fn is_leaf_shard(&self) -> bool {
155 if self.0 <= *INTERMEDIATE_SHARDS_RANGE.end() || self.0 > Self::MAX_SHARD_INDEX {
156 return false;
157 }
158
159 self.0 & INTERMEDIATE_SHARD_MASK != 0
160 }
161
162 #[inline(always)]
164 pub const fn is_real(&self) -> bool {
165 !self.is_phantom_shard()
166 }
167
168 #[inline(always)]
170 pub const fn is_phantom_shard(&self) -> bool {
171 if self.0 <= *INTERMEDIATE_SHARDS_RANGE.end() || self.0 > Self::MAX_SHARD_INDEX {
172 return false;
173 }
174
175 self.0 & INTERMEDIATE_SHARD_MASK == 0
176 }
177
178 #[inline(always)]
180 pub const fn is_child_of(self, parent: Self) -> bool {
181 match self.shard_kind() {
182 Some(ShardKind::BeaconChain) => false,
183 Some(ShardKind::IntermediateShard | ShardKind::Phantom) => parent.is_beacon_chain(),
184 Some(ShardKind::LeafShard) => {
185 self.0 & INTERMEDIATE_SHARD_MASK == parent.0
187 }
188 None => false,
189 }
190 }
191
192 #[inline(always)]
194 pub const fn parent_shard(self) -> Option<ShardIndex> {
195 match self.shard_kind()? {
196 ShardKind::BeaconChain => None,
197 ShardKind::IntermediateShard | ShardKind::Phantom => Some(ShardIndex::BEACON_CHAIN),
198 ShardKind::LeafShard => Some(Self(self.0 & INTERMEDIATE_SHARD_MASK)),
199 }
200 }
201
202 #[inline(always)]
204 pub const fn shard_kind(&self) -> Option<ShardKind> {
205 if self.0 == Self::BEACON_CHAIN.0 {
206 Some(ShardKind::BeaconChain)
207 } else if self.0 >= *INTERMEDIATE_SHARDS_RANGE.start()
208 && self.0 <= *INTERMEDIATE_SHARDS_RANGE.end()
209 {
210 Some(ShardKind::IntermediateShard)
211 } else if self.0 > Self::MAX_SHARD_INDEX {
212 None
213 } else if self.0 & INTERMEDIATE_SHARD_MASK == 0 {
214 Some(ShardKind::Phantom)
216 } else {
217 Some(ShardKind::LeafShard)
218 }
219 }
220}
221
222#[derive(Debug, Copy, Clone, Eq, PartialEq, TrivialType)]
226#[cfg_attr(feature = "scale-codec", derive(Encode, Decode, MaxEncodedLen))]
227#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
228#[repr(C)]
229pub struct NumShardsUnchecked {
230 pub intermediate_shards: u16,
232 pub leaf_shards_per_intermediate_shard: u16,
234}
235
236impl From<NumShards> for NumShardsUnchecked {
237 fn from(value: NumShards) -> Self {
238 Self {
239 intermediate_shards: value.intermediate_shards.get(),
240 leaf_shards_per_intermediate_shard: value.leaf_shards_per_intermediate_shard.get(),
241 }
242 }
243}
244
/// Validated shard counts.
///
/// Instances are created through `NumShards::new()`, which the `TryFrom`, `Decode` and
/// `Deserialize` implementations all delegate to.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "scale-codec", derive(Encode, MaxEncodedLen))]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct NumShards {
    /// Number of intermediate shards
    intermediate_shards: NonZeroU16,
    /// Number of leaf shards per intermediate shard
    leaf_shards_per_intermediate_shard: NonZeroU16,
}
255
#[cfg(feature = "scale-codec")]
impl Decode for NumShards {
    /// Decode both counts, then validate them with `NumShards::new()`
    fn decode<I: Input>(input: &mut I) -> Result<Self, parity_scale_codec::Error> {
        let intermediate_shards = <NonZeroU16 as Decode>::decode(input)
            .map_err(|error| error.chain("Could not decode `NumShards::intermediate_shards`"))?;
        let leaf_shards_per_intermediate_shard =
            <NonZeroU16 as Decode>::decode(input).map_err(|error| {
                error.chain("Could not decode `NumShards::leaf_shards_per_intermediate_shard`")
            })?;

        match Self::new(intermediate_shards, leaf_shards_per_intermediate_shard) {
            Some(num_shards) => Ok(num_shards),
            None => Err("Invalid `NumShards`".into()),
        }
    }
}
269
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for NumShards {
    /// Deserialize the raw counts, then validate them with `NumShards::new()`
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Unvalidated wire shape; the struct name is kept because serde reports it in errors
        #[derive(Deserialize)]
        struct NumShards {
            intermediate_shards: NonZeroU16,
            leaf_shards_per_intermediate_shard: NonZeroU16,
        }

        let NumShards {
            intermediate_shards,
            leaf_shards_per_intermediate_shard,
        } = NumShards::deserialize(deserializer)?;

        Self::new(intermediate_shards, leaf_shards_per_intermediate_shard)
            .ok_or_else(|| serde::de::Error::custom("Invalid `NumShards`"))
    }
}
291
292impl TryFrom<NumShardsUnchecked> for NumShards {
293 type Error = ();
294
295 fn try_from(value: NumShardsUnchecked) -> Result<Self, Self::Error> {
296 Self::new(
297 NonZeroU16::new(value.intermediate_shards).ok_or(())?,
298 NonZeroU16::new(value.leaf_shards_per_intermediate_shard).ok_or(())?,
299 )
300 .ok_or(())
301 }
302}
303
304impl NumShards {
305 #[inline(always)]
312 pub const fn new(
313 intermediate_shards: NonZeroU16,
314 leaf_shards_per_intermediate_shard: NonZeroU16,
315 ) -> Option<Self> {
316 if intermediate_shards.get()
317 > (*INTERMEDIATE_SHARDS_RANGE.end() - *INTERMEDIATE_SHARDS_RANGE.start() + 1) as u16
318 {
319 return None;
320 }
321
322 let num_shards = Self {
323 intermediate_shards,
324 leaf_shards_per_intermediate_shard,
325 };
326
327 if num_shards.leaf_shards() > ShardIndex::MAX_SHARDS {
328 return None;
329 }
330
331 Some(num_shards)
332 }
333
334 #[inline(always)]
336 pub const fn intermediate_shards(self) -> NonZeroU16 {
337 self.intermediate_shards
338 }
339 #[inline(always)]
341 pub const fn leaf_shards_per_intermediate_shard(self) -> NonZeroU16 {
342 self.leaf_shards_per_intermediate_shard
343 }
344
345 #[inline(always)]
347 pub const fn leaf_shards(&self) -> NonZeroU32 {
348 NonZeroU32::new(
349 self.intermediate_shards.get() as u32
350 * self.leaf_shards_per_intermediate_shard.get() as u32,
351 )
352 .expect("Not zero; qed")
353 }
354
355 #[inline(always)]
357 pub fn iter_intermediate_shards(&self) -> impl Iterator<Item = ShardIndex> {
358 INTERMEDIATE_SHARDS_RANGE
359 .take(usize::from(self.intermediate_shards.get()))
360 .map(ShardIndex)
361 }
362
363 #[inline(always)]
365 pub fn iter_leaf_shards(&self) -> impl Iterator<Item = ShardIndex> {
366 self.iter_intermediate_shards()
367 .flat_map(|intermediate_shard| {
368 (0..u32::from(self.leaf_shards_per_intermediate_shard.get())).map(
369 move |leaf_shard_index| {
370 ShardIndex(
371 (leaf_shard_index << INTERMEDIATE_SHARD_BITS) | intermediate_shard.0,
372 )
373 },
374 )
375 })
376 }
377
378 #[inline]
380 pub fn derive_shard_index(
381 &self,
382 public_key_hash: &Blake3Hash,
383 shard_commitments_root: &ShardCommitmentHash,
384 shard_membership_entropy: &ShardMembershipEntropy,
385 history_size: HistorySize,
386 ) -> ShardIndex {
387 let hash = single_block_keyed_hash(public_key_hash, &{
388 let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
389 + ShardMembershipEntropy::SIZE
390 + HistorySize::SIZE as usize];
391 bytes_to_hash[..ShardCommitmentHash::SIZE]
392 .copy_from_slice(shard_commitments_root.as_bytes());
393 bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
394 .copy_from_slice(shard_membership_entropy.as_bytes());
395 bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
396 .copy_from_slice(history_size.as_bytes());
397 bytes_to_hash
398 })
399 .expect("Input is smaller than block size; qed");
400 let shard_index_offset =
403 NanoU256::from_le_bytes(hash) % u64::from(self.leaf_shards().get());
404
405 self.iter_leaf_shards()
406 .nth(shard_index_offset as usize)
407 .unwrap_or(ShardIndex::BEACON_CHAIN)
408 }
409
410 #[inline]
414 pub fn derive_shard_commitment_index(
415 &self,
416 public_key_hash: &Blake3Hash,
417 shard_commitments_root: &ShardCommitmentHash,
418 shard_membership_entropy: &ShardMembershipEntropy,
419 history_size: HistorySize,
420 ) -> u32 {
421 let hash = single_block_keyed_hash(public_key_hash, &{
422 let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
423 + ShardMembershipEntropy::SIZE
424 + HistorySize::SIZE as usize];
425 bytes_to_hash[..ShardCommitmentHash::SIZE]
426 .copy_from_slice(shard_commitments_root.as_bytes());
427 bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
428 .copy_from_slice(shard_membership_entropy.as_bytes());
429 bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
430 .copy_from_slice(history_size.as_bytes());
431 bytes_to_hash
432 })
433 .expect("Input is smaller than block size; qed");
434 const {
435 assert!(SolutionShardCommitment::NUM_LEAVES.is_power_of_two());
436 }
437 u32::from_le_bytes([hash[0], hash[1], hash[2], hash[3]])
438 % SolutionShardCommitment::NUM_LEAVES as u32
439 }
440
441 #[inline]
444 pub fn derive_shard_index_and_shard_commitment_index(
445 &self,
446 public_key_hash: &Blake3Hash,
447 shard_commitments_root: &ShardCommitmentHash,
448 shard_membership_entropy: &ShardMembershipEntropy,
449 history_size: HistorySize,
450 ) -> (ShardIndex, u32) {
451 let hash = single_block_keyed_hash(public_key_hash, &{
452 let mut bytes_to_hash = [0u8; ShardCommitmentHash::SIZE
453 + ShardMembershipEntropy::SIZE
454 + HistorySize::SIZE as usize];
455 bytes_to_hash[..ShardCommitmentHash::SIZE]
456 .copy_from_slice(shard_commitments_root.as_bytes());
457 bytes_to_hash[ShardCommitmentHash::SIZE..][..ShardMembershipEntropy::SIZE]
458 .copy_from_slice(shard_membership_entropy.as_bytes());
459 bytes_to_hash[ShardCommitmentHash::SIZE + ShardMembershipEntropy::SIZE..]
460 .copy_from_slice(history_size.as_bytes());
461 bytes_to_hash
462 })
463 .expect("Input is smaller than block size; qed");
464
465 let shard_index_offset =
468 NanoU256::from_le_bytes(hash) % u64::from(self.leaf_shards().get());
469
470 let shard_index = self
471 .iter_leaf_shards()
472 .nth(shard_index_offset as usize)
473 .unwrap_or(ShardIndex::BEACON_CHAIN);
474
475 const {
476 assert!(SolutionShardCommitment::NUM_LEAVES.is_power_of_two());
477 }
478 let shard_commitment_index = u32::from_le_bytes([hash[0], hash[1], hash[2], hash[3]])
479 % SolutionShardCommitment::NUM_LEAVES as u32;
480
481 (shard_index, shard_commitment_index)
482 }
483}