ab_erasure_coding/
lib.rs

1#![feature(maybe_uninit_slice, trusted_len)]
2#![no_std]
3
4extern crate alloc;
5
6use alloc::vec;
7use alloc::vec::Vec;
8use core::iter::TrustedLen;
9use core::mem::MaybeUninit;
10use reed_solomon_simd::Error;
11use reed_solomon_simd::engine::DefaultEngine;
12use reed_solomon_simd::rate::{HighRateDecoder, HighRateEncoder, RateDecoder, RateEncoder};
13
14/// Error that occurs when calling [`ErasureCoding::recover()`]
15#[derive(Debug, Clone, PartialEq, thiserror::Error)]
16pub enum ErasureCodingError {
17    /// Decoder error
18    #[error("Decoder error: {0}")]
19    DecoderError(#[from] Error),
20    /// Ignored source shard
21    #[error("Ignored source shard {index}")]
22    IgnoredSourceShard {
23        /// Shard index
24        index: usize,
25    },
26    /// Wrong source shard byte length
27    #[error("Wrong source shard byte length: expected {expected}, actual {actual}")]
28    WrongSourceShardByteLength { expected: usize, actual: usize },
29    /// Wrong parity shard byte length
30    #[error("Wrong parity shard byte length: expected {expected}, actual {actual}")]
31    WrongParityShardByteLength { expected: usize, actual: usize },
32}
33
/// What the caller knows about a single shard when calling [`ErasureCoding::recover()`]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RecoveryShardState<PresentShard, MissingShard> {
    /// The shard is available, its contents are used as recovery input
    Present(PresentShard),
    /// The shard is unavailable and must be reconstructed into the contained buffer
    MissingRecover(MissingShard),
    /// The shard is unavailable and the caller does not need it reconstructed.
    ///
    /// Only parity shards may be in this state; every source shard must always be either present
    /// or recovered.
    MissingIgnore,
}
47
/// Erasure coding abstraction.
///
/// Produces parity shards from source shards ([`ErasureCoding::extend()`]) and reconstructs
/// missing source and parity shards ([`ErasureCoding::recover()`]).
#[derive(Clone, Debug)]
pub struct ErasureCoding {}
53
54impl Default for ErasureCoding {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl ErasureCoding {
61    /// Create new erasure coding instance
62    pub fn new() -> Self {
63        Self {}
64    }
65
66    /// Extend sources using erasure coding
67    pub fn extend<'a, SourceIter, ParityIter, SourceBytes, ParityBytes>(
68        &self,
69        source: SourceIter,
70        parity: ParityIter,
71    ) -> Result<(), ErasureCodingError>
72    where
73        SourceIter: TrustedLen<Item = SourceBytes>,
74        ParityIter: TrustedLen<Item = ParityBytes>,
75        SourceBytes: AsRef<[u8]> + 'a,
76        ParityBytes: AsMut<[u8]> + 'a,
77    {
78        let mut source = source.peekable();
79        let shard_byte_len = source
80            .peek()
81            .map(|shard| shard.as_ref().len())
82            .unwrap_or_default();
83
84        let mut encoder = HighRateEncoder::new(
85            source.size_hint().0,
86            parity.size_hint().0,
87            shard_byte_len,
88            DefaultEngine::new(),
89            None,
90        )?;
91
92        for shard in source {
93            encoder.add_original_shard(shard)?;
94        }
95
96        let result = encoder.encode()?;
97
98        for (input, mut output) in result.recovery_iter().zip(parity) {
99            let output = output.as_mut();
100            if output.len() != shard_byte_len {
101                return Err(ErasureCodingError::WrongParityShardByteLength {
102                    expected: shard_byte_len,
103                    actual: output.len(),
104                });
105            }
106            output.copy_from_slice(input);
107        }
108
109        Ok(())
110    }
111
112    /// Recover missing shards
113    pub fn recover<'a, SourceIter, ParityIter>(
114        &self,
115        source: SourceIter,
116        parity: ParityIter,
117    ) -> Result<(), ErasureCodingError>
118    where
119        SourceIter: TrustedLen<Item = RecoveryShardState<&'a [u8], &'a mut [u8]>>,
120        ParityIter: TrustedLen<Item = RecoveryShardState<&'a [u8], &'a mut [u8]>>,
121    {
122        let num_source = source.size_hint().0;
123        let num_parity = parity.size_hint().0;
124        let mut source = source.enumerate().peekable();
125        let mut parity = parity.enumerate().peekable();
126        let mut shard_byte_len = 0;
127
128        while let Some((_, shard)) = source.peek_mut() {
129            match shard {
130                RecoveryShardState::Present(shard_bytes) => {
131                    shard_byte_len = shard_bytes.len();
132                    break;
133                }
134                RecoveryShardState::MissingRecover(shard_bytes) => {
135                    shard_byte_len = shard_bytes.len();
136                    break;
137                }
138                RecoveryShardState::MissingIgnore => {
139                    // Skip, it is inconsequential here
140                    source.next();
141                }
142            }
143        }
144        if shard_byte_len == 0 {
145            while let Some((_, shard)) = parity.peek_mut() {
146                match shard {
147                    RecoveryShardState::Present(shard_bytes) => {
148                        shard_byte_len = shard_bytes.len();
149                        break;
150                    }
151                    RecoveryShardState::MissingRecover(shard_bytes) => {
152                        shard_byte_len = shard_bytes.len();
153                        break;
154                    }
155                    RecoveryShardState::MissingIgnore => {
156                        // Skip, it is inconsequential here
157                        parity.next();
158                    }
159                }
160            }
161        }
162
163        let mut all_source_shards = vec![MaybeUninit::uninit(); num_source];
164        let mut parity_shards_to_recover = Vec::new();
165
166        {
167            let mut decoder = HighRateDecoder::new(
168                num_source,
169                num_parity,
170                shard_byte_len,
171                DefaultEngine::new(),
172                None,
173            )?;
174
175            let mut source_shards_to_recover = Vec::new();
176            for (index, shard) in source {
177                match shard {
178                    RecoveryShardState::Present(shard_bytes) => {
179                        all_source_shards[index].write(shard_bytes);
180                        decoder.add_original_shard(index, shard_bytes)?;
181                    }
182                    RecoveryShardState::MissingRecover(shard_bytes) => {
183                        source_shards_to_recover.push((index, shard_bytes));
184                    }
185                    RecoveryShardState::MissingIgnore => {
186                        return Err(ErasureCodingError::IgnoredSourceShard { index });
187                    }
188                }
189            }
190
191            for (index, shard) in parity {
192                match shard {
193                    RecoveryShardState::Present(shard_bytes) => {
194                        decoder.add_recovery_shard(index, shard_bytes)?;
195                    }
196                    RecoveryShardState::MissingRecover(shard_bytes) => {
197                        parity_shards_to_recover.push((index, shard_bytes));
198                    }
199                    RecoveryShardState::MissingIgnore => {}
200                }
201            }
202
203            let result = decoder.decode()?;
204
205            for (index, output) in source_shards_to_recover {
206                if output.len() != shard_byte_len {
207                    return Err(ErasureCodingError::WrongSourceShardByteLength {
208                        expected: shard_byte_len,
209                        actual: output.len(),
210                    });
211                }
212                let shard = result
213                    .restored_original(index)
214                    .expect("Always corresponds to a missing original shard; qed");
215                output.copy_from_slice(shard);
216                all_source_shards[index].write(output);
217            }
218        }
219
220        if !parity_shards_to_recover.is_empty() {
221            // SAFETY: All `all_source_shards` are either initialized from the start or recovered
222            let all_source_shards = unsafe { all_source_shards.assume_init_ref() };
223
224            let mut encoder = HighRateEncoder::new(
225                num_source,
226                num_parity,
227                shard_byte_len,
228                DefaultEngine::new(),
229                None,
230            )?;
231
232            for shard in all_source_shards {
233                encoder.add_original_shard(shard)?;
234            }
235
236            let result = encoder.encode()?;
237
238            for (index, output) in parity_shards_to_recover {
239                if output.len() != shard_byte_len {
240                    return Err(ErasureCodingError::WrongParityShardByteLength {
241                        expected: shard_byte_len,
242                        actual: output.len(),
243                    });
244                }
245                output.copy_from_slice(
246                    result
247                        .recovery(index)
248                        .expect("Always corresponds to a missing parity shard; qed"),
249                );
250            }
251        }
252
253        Ok(())
254    }
255}