1#![feature(maybe_uninit_slice, trusted_len)]
2#![no_std]
3
4extern crate alloc;
5
6use alloc::vec;
7use alloc::vec::Vec;
8use core::iter::TrustedLen;
9use core::mem::MaybeUninit;
10use reed_solomon_simd::Error;
11use reed_solomon_simd::engine::DefaultEngine;
12use reed_solomon_simd::rate::{HighRateDecoder, HighRateEncoder, RateDecoder, RateEncoder};
13
14#[derive(Debug, Clone, PartialEq, thiserror::Error)]
16pub enum ErasureCodingError {
17 #[error("Decoder error: {0}")]
19 DecoderError(#[from] Error),
20 #[error("Wrong source shard byte length: expected {expected}, actual {actual}")]
22 WrongSourceShardByteLength { expected: usize, actual: usize },
23 #[error("Wrong parity shard byte length: expected {expected}, actual {actual}")]
25 WrongParityShardByteLength { expected: usize, actual: usize },
26}
27
/// Per-shard input to [`ErasureCoding::recover`].
///
/// `PresentShard` is the type of an available shard's bytes, and `MissingShard` is the type
/// of the output buffer for a shard that should be reconstructed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RecoveryShardState<PresentShard, MissingShard> {
    /// The shard is available; its bytes are fed to the codec.
    Present(PresentShard),
    /// The shard is missing; its reconstructed bytes are written into this buffer.
    MissingRecover(MissingShard),
    /// The shard is missing and the caller does not need it reconstructed.
    MissingIgnore,
}
38
/// Reed-Solomon erasure coder built on `reed_solomon_simd`'s high-rate encoder/decoder.
///
/// Carries no configuration: every operation constructs its own encoder or decoder.
#[derive(Clone, Debug)]
pub struct ErasureCoding {}
44
45impl Default for ErasureCoding {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl ErasureCoding {
52 pub fn new() -> Self {
54 Self {}
55 }
56
57 pub fn extend<'a, SourceIter, ParityIter, SourceBytes, ParityBytes>(
59 &self,
60 source: SourceIter,
61 parity: ParityIter,
62 ) -> Result<(), ErasureCodingError>
63 where
64 SourceIter: TrustedLen<Item = SourceBytes>,
65 ParityIter: TrustedLen<Item = ParityBytes>,
66 SourceBytes: AsRef<[u8]> + 'a,
67 ParityBytes: AsMut<[u8]> + 'a,
68 {
69 let mut source = source.peekable();
70 let shard_byte_len = source
71 .peek()
72 .map(|shard| shard.as_ref().len())
73 .unwrap_or_default();
74
75 let mut encoder = HighRateEncoder::new(
76 source.size_hint().0,
77 parity.size_hint().0,
78 shard_byte_len,
79 DefaultEngine::new(),
80 None,
81 )?;
82
83 for shard in source {
84 encoder.add_original_shard(shard)?;
85 }
86
87 let result = encoder.encode()?;
88
89 for (input, mut output) in result.recovery_iter().zip(parity) {
90 let output = output.as_mut();
91 if output.len() != shard_byte_len {
92 return Err(ErasureCodingError::WrongParityShardByteLength {
93 expected: shard_byte_len,
94 actual: output.len(),
95 });
96 }
97 output.copy_from_slice(input);
98 }
99
100 Ok(())
101 }
102
103 pub fn recover<'a, SourceIter, ParityIter>(
105 &self,
106 source: SourceIter,
107 parity: ParityIter,
108 ) -> Result<(), ErasureCodingError>
109 where
110 SourceIter: TrustedLen<Item = RecoveryShardState<&'a [u8], &'a mut [u8]>>,
111 ParityIter: TrustedLen<Item = RecoveryShardState<&'a [u8], &'a mut [u8]>>,
112 {
113 let num_source = source.size_hint().0;
114 let num_parity = parity.size_hint().0;
115 let mut source = source.enumerate().peekable();
116 let mut parity = parity.enumerate().peekable();
117 let mut shard_byte_len = 0;
118
119 while let Some((_, shard)) = source.peek_mut() {
120 match shard {
121 RecoveryShardState::Present(shard_bytes) => {
122 shard_byte_len = shard_bytes.len();
123 break;
124 }
125 RecoveryShardState::MissingRecover(shard_bytes) => {
126 shard_byte_len = shard_bytes.len();
127 break;
128 }
129 RecoveryShardState::MissingIgnore => {
130 source.next();
132 }
133 }
134 }
135 if shard_byte_len == 0 {
136 while let Some((_, shard)) = parity.peek_mut() {
137 match shard {
138 RecoveryShardState::Present(shard_bytes) => {
139 shard_byte_len = shard_bytes.len();
140 break;
141 }
142 RecoveryShardState::MissingRecover(shard_bytes) => {
143 shard_byte_len = shard_bytes.len();
144 break;
145 }
146 RecoveryShardState::MissingIgnore => {
147 parity.next();
149 }
150 }
151 }
152 }
153
154 let mut all_source_shards = vec![MaybeUninit::uninit(); num_source];
155 let mut parity_shards_to_recover = Vec::new();
156
157 {
158 let mut decoder = HighRateDecoder::new(
159 num_source,
160 num_parity,
161 shard_byte_len,
162 DefaultEngine::new(),
163 None,
164 )?;
165
166 let mut source_shards_to_recover = Vec::new();
167 for (index, shard) in source {
168 match shard {
169 RecoveryShardState::Present(shard_bytes) => {
170 all_source_shards[index].write(shard_bytes);
171 decoder.add_original_shard(index, shard_bytes)?;
172 }
173 RecoveryShardState::MissingRecover(shard_bytes) => {
174 source_shards_to_recover.push((index, shard_bytes));
175 }
176 RecoveryShardState::MissingIgnore => {}
177 }
178 }
179
180 for (index, shard) in parity {
181 match shard {
182 RecoveryShardState::Present(shard_bytes) => {
183 decoder.add_recovery_shard(index, shard_bytes)?;
184 }
185 RecoveryShardState::MissingRecover(shard_bytes) => {
186 parity_shards_to_recover.push((index, shard_bytes));
187 }
188 RecoveryShardState::MissingIgnore => {}
189 }
190 }
191
192 let result = decoder.decode()?;
193
194 for (index, output) in source_shards_to_recover {
195 if output.len() != shard_byte_len {
196 return Err(ErasureCodingError::WrongSourceShardByteLength {
197 expected: shard_byte_len,
198 actual: output.len(),
199 });
200 }
201 let shard = result
202 .restored_original(index)
203 .expect("Always corresponds to a missing original shard; qed");
204 output.copy_from_slice(shard);
205 all_source_shards[index].write(output);
206 }
207 }
208
209 if !parity_shards_to_recover.is_empty() {
210 let all_source_shards = unsafe { all_source_shards.assume_init_ref() };
211
212 let mut encoder = HighRateEncoder::new(
213 num_source,
214 num_parity,
215 shard_byte_len,
216 DefaultEngine::new(),
217 None,
218 )?;
219
220 for shard in all_source_shards {
221 encoder.add_original_shard(shard)?;
222 }
223
224 let result = encoder.encode()?;
225
226 for (index, output) in parity_shards_to_recover {
227 if output.len() != shard_byte_len {
228 return Err(ErasureCodingError::WrongParityShardByteLength {
229 expected: shard_byte_len,
230 actual: output.len(),
231 });
232 }
233 output.copy_from_slice(
234 result
235 .recovery(index)
236 .expect("Always corresponds to a missing parity shard; qed"),
237 );
238 }
239 }
240
241 Ok(())
242 }
243}