1#![feature(maybe_uninit_slice, trusted_len)]
2#![no_std]
3
4extern crate alloc;
5
6use alloc::vec;
7use alloc::vec::Vec;
8use core::iter::TrustedLen;
9use core::mem::MaybeUninit;
10use reed_solomon_simd::Error;
11use reed_solomon_simd::engine::DefaultEngine;
12use reed_solomon_simd::rate::{HighRateDecoder, HighRateEncoder, RateDecoder, RateEncoder};
13
14#[derive(Debug, Clone, PartialEq, thiserror::Error)]
16pub enum ErasureCodingError {
17 #[error("Decoder error: {0}")]
19 DecoderError(#[from] Error),
20 #[error("Ignored source shard {index}")]
22 IgnoredSourceShard {
23 index: usize,
25 },
26 #[error("Wrong source shard byte length: expected {expected}, actual {actual}")]
28 WrongSourceShardByteLength { expected: usize, actual: usize },
29 #[error("Wrong parity shard byte length: expected {expected}, actual {actual}")]
31 WrongParityShardByteLength { expected: usize, actual: usize },
32}
33
/// State of a single shard passed to [`ErasureCoding::recover`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RecoveryShardState<PresentShard, MissingShard> {
    /// The shard is available and its contents are used as recovery input.
    Present(PresentShard),
    /// The shard is missing; the provided buffer receives its recovered contents.
    MissingRecover(MissingShard),
    /// The shard is missing and does not need to be recovered (only valid for parity shards).
    MissingIgnore,
}
47
/// Stateless entry point for erasure-coding operations (extension and recovery).
#[derive(Clone, Debug)]
pub struct ErasureCoding {}
53
54impl Default for ErasureCoding {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl ErasureCoding {
61 pub fn new() -> Self {
63 Self {}
64 }
65
66 pub fn extend<'a, SourceIter, ParityIter, SourceBytes, ParityBytes>(
68 &self,
69 source: SourceIter,
70 parity: ParityIter,
71 ) -> Result<(), ErasureCodingError>
72 where
73 SourceIter: TrustedLen<Item = SourceBytes>,
74 ParityIter: TrustedLen<Item = ParityBytes>,
75 SourceBytes: AsRef<[u8]> + 'a,
76 ParityBytes: AsMut<[u8]> + 'a,
77 {
78 let mut source = source.peekable();
79 let shard_byte_len = source
80 .peek()
81 .map(|shard| shard.as_ref().len())
82 .unwrap_or_default();
83
84 let mut encoder = HighRateEncoder::new(
85 source.size_hint().0,
86 parity.size_hint().0,
87 shard_byte_len,
88 DefaultEngine::new(),
89 None,
90 )?;
91
92 for shard in source {
93 encoder.add_original_shard(shard)?;
94 }
95
96 let result = encoder.encode()?;
97
98 for (input, mut output) in result.recovery_iter().zip(parity) {
99 let output = output.as_mut();
100 if output.len() != shard_byte_len {
101 return Err(ErasureCodingError::WrongParityShardByteLength {
102 expected: shard_byte_len,
103 actual: output.len(),
104 });
105 }
106 output.copy_from_slice(input);
107 }
108
109 Ok(())
110 }
111
112 pub fn recover<'a, SourceIter, ParityIter>(
114 &self,
115 source: SourceIter,
116 parity: ParityIter,
117 ) -> Result<(), ErasureCodingError>
118 where
119 SourceIter: TrustedLen<Item = RecoveryShardState<&'a [u8], &'a mut [u8]>>,
120 ParityIter: TrustedLen<Item = RecoveryShardState<&'a [u8], &'a mut [u8]>>,
121 {
122 let num_source = source.size_hint().0;
123 let num_parity = parity.size_hint().0;
124 let mut source = source.enumerate().peekable();
125 let mut parity = parity.enumerate().peekable();
126 let mut shard_byte_len = 0;
127
128 while let Some((_, shard)) = source.peek_mut() {
129 match shard {
130 RecoveryShardState::Present(shard_bytes) => {
131 shard_byte_len = shard_bytes.len();
132 break;
133 }
134 RecoveryShardState::MissingRecover(shard_bytes) => {
135 shard_byte_len = shard_bytes.len();
136 break;
137 }
138 RecoveryShardState::MissingIgnore => {
139 source.next();
141 }
142 }
143 }
144 if shard_byte_len == 0 {
145 while let Some((_, shard)) = parity.peek_mut() {
146 match shard {
147 RecoveryShardState::Present(shard_bytes) => {
148 shard_byte_len = shard_bytes.len();
149 break;
150 }
151 RecoveryShardState::MissingRecover(shard_bytes) => {
152 shard_byte_len = shard_bytes.len();
153 break;
154 }
155 RecoveryShardState::MissingIgnore => {
156 parity.next();
158 }
159 }
160 }
161 }
162
163 let mut all_source_shards = vec![MaybeUninit::uninit(); num_source];
164 let mut parity_shards_to_recover = Vec::new();
165
166 {
167 let mut decoder = HighRateDecoder::new(
168 num_source,
169 num_parity,
170 shard_byte_len,
171 DefaultEngine::new(),
172 None,
173 )?;
174
175 let mut source_shards_to_recover = Vec::new();
176 for (index, shard) in source {
177 match shard {
178 RecoveryShardState::Present(shard_bytes) => {
179 all_source_shards[index].write(shard_bytes);
180 decoder.add_original_shard(index, shard_bytes)?;
181 }
182 RecoveryShardState::MissingRecover(shard_bytes) => {
183 source_shards_to_recover.push((index, shard_bytes));
184 }
185 RecoveryShardState::MissingIgnore => {
186 return Err(ErasureCodingError::IgnoredSourceShard { index });
187 }
188 }
189 }
190
191 for (index, shard) in parity {
192 match shard {
193 RecoveryShardState::Present(shard_bytes) => {
194 decoder.add_recovery_shard(index, shard_bytes)?;
195 }
196 RecoveryShardState::MissingRecover(shard_bytes) => {
197 parity_shards_to_recover.push((index, shard_bytes));
198 }
199 RecoveryShardState::MissingIgnore => {}
200 }
201 }
202
203 let result = decoder.decode()?;
204
205 for (index, output) in source_shards_to_recover {
206 if output.len() != shard_byte_len {
207 return Err(ErasureCodingError::WrongSourceShardByteLength {
208 expected: shard_byte_len,
209 actual: output.len(),
210 });
211 }
212 let shard = result
213 .restored_original(index)
214 .expect("Always corresponds to a missing original shard; qed");
215 output.copy_from_slice(shard);
216 all_source_shards[index].write(output);
217 }
218 }
219
220 if !parity_shards_to_recover.is_empty() {
221 let all_source_shards = unsafe { all_source_shards.assume_init_ref() };
223
224 let mut encoder = HighRateEncoder::new(
225 num_source,
226 num_parity,
227 shard_byte_len,
228 DefaultEngine::new(),
229 None,
230 )?;
231
232 for shard in all_source_shards {
233 encoder.add_original_shard(shard)?;
234 }
235
236 let result = encoder.encode()?;
237
238 for (index, output) in parity_shards_to_recover {
239 if output.len() != shard_byte_len {
240 return Err(ErasureCodingError::WrongParityShardByteLength {
241 expected: shard_byte_len,
242 actual: output.len(),
243 });
244 }
245 output.copy_from_slice(
246 result
247 .recovery(index)
248 .expect("Always corresponds to a missing parity shard; qed"),
249 );
250 }
251 }
252
253 Ok(())
254 }
255}