ab_erasure_coding/
lib.rs

1#![feature(maybe_uninit_slice, trusted_len)]
2#![no_std]
3
4extern crate alloc;
5
6use alloc::vec;
7use alloc::vec::Vec;
8use core::iter::TrustedLen;
9use core::mem::MaybeUninit;
10use reed_solomon_simd::Error;
11use reed_solomon_simd::engine::DefaultEngine;
12use reed_solomon_simd::rate::{HighRateDecoder, HighRateEncoder, RateDecoder, RateEncoder};
13
14/// Error that occurs when calling [`ErasureCoding::recover()`]
15#[derive(Debug, Clone, PartialEq, thiserror::Error)]
16pub enum ErasureCodingError {
17    /// Decoder error
18    #[error("Decoder error: {0}")]
19    DecoderError(#[from] Error),
20    /// Wrong source shard byte length
21    #[error("Wrong source shard byte length: expected {expected}, actual {actual}")]
22    WrongSourceShardByteLength { expected: usize, actual: usize },
23    /// Wrong parity shard byte length
24    #[error("Wrong parity shard byte length: expected {expected}, actual {actual}")]
25    WrongParityShardByteLength { expected: usize, actual: usize },
26}
27
/// State of the shard for recovery.
///
/// `PresentShard` is the type carried by a shard whose contents are known, `MissingShard` is the
/// type carried by a shard whose contents must be produced (typically a writable buffer).
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum RecoveryShardState<PresentShard, MissingShard> {
    /// Shard is present and will be used for recovery
    Present(PresentShard),
    /// Shard is missing and needs to be recovered
    MissingRecover(MissingShard),
    /// Shard is missing and does not need to be recovered
    MissingIgnore,
}
38
/// Erasure coding abstraction.
///
/// Supports creation of parity records and recovery of missing data.
///
/// The type currently has no fields, so instances carry no state.
#[derive(Debug, Clone)]
pub struct ErasureCoding {}
44
45impl Default for ErasureCoding {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51impl ErasureCoding {
52    /// Create new erasure coding instance
53    pub fn new() -> Self {
54        Self {}
55    }
56
57    /// Extend sources using erasure coding
58    pub fn extend<'a, SourceIter, ParityIter, SourceBytes, ParityBytes>(
59        &self,
60        source: SourceIter,
61        parity: ParityIter,
62    ) -> Result<(), ErasureCodingError>
63    where
64        SourceIter: TrustedLen<Item = SourceBytes>,
65        ParityIter: TrustedLen<Item = ParityBytes>,
66        SourceBytes: AsRef<[u8]> + 'a,
67        ParityBytes: AsMut<[u8]> + 'a,
68    {
69        let mut source = source.peekable();
70        let shard_byte_len = source
71            .peek()
72            .map(|shard| shard.as_ref().len())
73            .unwrap_or_default();
74
75        let mut encoder = HighRateEncoder::new(
76            source.size_hint().0,
77            parity.size_hint().0,
78            shard_byte_len,
79            DefaultEngine::new(),
80            None,
81        )?;
82
83        for shard in source {
84            encoder.add_original_shard(shard)?;
85        }
86
87        let result = encoder.encode()?;
88
89        for (input, mut output) in result.recovery_iter().zip(parity) {
90            let output = output.as_mut();
91            if output.len() != shard_byte_len {
92                return Err(ErasureCodingError::WrongParityShardByteLength {
93                    expected: shard_byte_len,
94                    actual: output.len(),
95                });
96            }
97            output.copy_from_slice(input);
98        }
99
100        Ok(())
101    }
102
103    /// Recover missing shards
104    pub fn recover<'a, SourceIter, ParityIter>(
105        &self,
106        source: SourceIter,
107        parity: ParityIter,
108    ) -> Result<(), ErasureCodingError>
109    where
110        SourceIter: TrustedLen<Item = RecoveryShardState<&'a [u8], &'a mut [u8]>>,
111        ParityIter: TrustedLen<Item = RecoveryShardState<&'a [u8], &'a mut [u8]>>,
112    {
113        let num_source = source.size_hint().0;
114        let num_parity = parity.size_hint().0;
115        let mut source = source.enumerate().peekable();
116        let mut parity = parity.enumerate().peekable();
117        let mut shard_byte_len = 0;
118
119        while let Some((_, shard)) = source.peek_mut() {
120            match shard {
121                RecoveryShardState::Present(shard_bytes) => {
122                    shard_byte_len = shard_bytes.len();
123                    break;
124                }
125                RecoveryShardState::MissingRecover(shard_bytes) => {
126                    shard_byte_len = shard_bytes.len();
127                    break;
128                }
129                RecoveryShardState::MissingIgnore => {
130                    // Skip, it is inconsequential here
131                    source.next();
132                }
133            }
134        }
135        if shard_byte_len == 0 {
136            while let Some((_, shard)) = parity.peek_mut() {
137                match shard {
138                    RecoveryShardState::Present(shard_bytes) => {
139                        shard_byte_len = shard_bytes.len();
140                        break;
141                    }
142                    RecoveryShardState::MissingRecover(shard_bytes) => {
143                        shard_byte_len = shard_bytes.len();
144                        break;
145                    }
146                    RecoveryShardState::MissingIgnore => {
147                        // Skip, it is inconsequential here
148                        parity.next();
149                    }
150                }
151            }
152        }
153
154        let mut all_source_shards = vec![MaybeUninit::uninit(); num_source];
155        let mut parity_shards_to_recover = Vec::new();
156
157        {
158            let mut decoder = HighRateDecoder::new(
159                num_source,
160                num_parity,
161                shard_byte_len,
162                DefaultEngine::new(),
163                None,
164            )?;
165
166            let mut source_shards_to_recover = Vec::new();
167            for (index, shard) in source {
168                match shard {
169                    RecoveryShardState::Present(shard_bytes) => {
170                        all_source_shards[index].write(shard_bytes);
171                        decoder.add_original_shard(index, shard_bytes)?;
172                    }
173                    RecoveryShardState::MissingRecover(shard_bytes) => {
174                        source_shards_to_recover.push((index, shard_bytes));
175                    }
176                    RecoveryShardState::MissingIgnore => {}
177                }
178            }
179
180            for (index, shard) in parity {
181                match shard {
182                    RecoveryShardState::Present(shard_bytes) => {
183                        decoder.add_recovery_shard(index, shard_bytes)?;
184                    }
185                    RecoveryShardState::MissingRecover(shard_bytes) => {
186                        parity_shards_to_recover.push((index, shard_bytes));
187                    }
188                    RecoveryShardState::MissingIgnore => {}
189                }
190            }
191
192            let result = decoder.decode()?;
193
194            for (index, output) in source_shards_to_recover {
195                if output.len() != shard_byte_len {
196                    return Err(ErasureCodingError::WrongSourceShardByteLength {
197                        expected: shard_byte_len,
198                        actual: output.len(),
199                    });
200                }
201                let shard = result
202                    .restored_original(index)
203                    .expect("Always corresponds to a missing original shard; qed");
204                output.copy_from_slice(shard);
205                all_source_shards[index].write(output);
206            }
207        }
208
209        if !parity_shards_to_recover.is_empty() {
210            let all_source_shards = unsafe { all_source_shards.assume_init_ref() };
211
212            let mut encoder = HighRateEncoder::new(
213                num_source,
214                num_parity,
215                shard_byte_len,
216                DefaultEngine::new(),
217                None,
218            )?;
219
220            for shard in all_source_shards {
221                encoder.add_original_shard(shard)?;
222            }
223
224            let result = encoder.encode()?;
225
226            for (index, output) in parity_shards_to_recover {
227                if output.len() != shard_byte_len {
228                    return Err(ErasureCodingError::WrongParityShardByteLength {
229                        expected: shard_byte_len,
230                        actual: output.len(),
231                    });
232                }
233                output.copy_from_slice(
234                    result
235                        .recovery(index)
236                        .expect("Always corresponds to a missing parity shard; qed"),
237                );
238            }
239        }
240
241        Ok(())
242    }
243}