1mod hazmat;
8#[cfg(test)]
9mod tests;
10
11use crate::platform::{
12 le_bytes_from_words_32, words_from_le_bytes_32, words_from_le_bytes_64, MAX_SIMD_DEGREE,
13 MAX_SIMD_DEGREE_OR_2,
14};
15use crate::portable::IncrementCounter;
16use crate::{
17 portable, BlockBytes, CVBytes, CVWords, BLOCK_LEN, CHUNK_END, CHUNK_LEN, CHUNK_START,
18 DERIVE_KEY_CONTEXT, DERIVE_KEY_MATERIAL, IV, KEYED_HASH, KEY_LEN, OUT_LEN, PARENT, ROOT,
19};
20use core::mem::MaybeUninit;
21use core::slice;
22
/// Captured state of one compression-function invocation whose result has not
/// yet been finalized: it can become either a chaining value
/// ([`ConstOutput::chaining_value`]) or the root hash ([`ConstOutput::root_hash`]).
struct ConstOutput {
    // Chaining value fed into the compression function.
    input_chaining_value: CVWords,
    // The 64-byte message block to compress.
    block: BlockBytes,
    // Number of meaningful bytes in `block` (<= BLOCK_LEN).
    block_len: u8,
    // Counter value for the compression call; must be 0 when finalized as the
    // root (enforced by a debug assertion in `root_hash`).
    counter: u64,
    // Domain flags for this block (e.g. CHUNK_START/CHUNK_END/PARENT), without
    // ROOT — `root_hash` ORs ROOT in itself.
    flags: u8,
}
31
32impl ConstOutput {
33 const fn chaining_value(&self) -> CVBytes {
34 let mut cv = self.input_chaining_value;
35 let block_words = words_from_le_bytes_64(&self.block);
36 portable::compress_in_place(
37 &mut cv,
38 &block_words,
39 self.block_len as u32,
40 self.counter,
41 self.flags as u32,
42 );
43 *le_bytes_from_words_32(&cv)
44 }
45
46 const fn root_hash(&self) -> [u8; OUT_LEN] {
47 debug_assert!(self.counter == 0);
48 let mut cv = self.input_chaining_value;
49 let block_words = words_from_le_bytes_64(&self.block);
50 portable::compress_in_place(
51 &mut cv,
52 &block_words,
53 self.block_len as u32,
54 0,
55 (self.flags | ROOT) as u32,
56 );
57 *le_bytes_from_words_32(&cv)
58 }
59}
60
/// Incremental hashing state for a single chunk of input (at most `CHUNK_LEN`
/// bytes), usable in `const` contexts.
struct ConstChunkState {
    // Current chaining value, updated as each full block is compressed.
    cv: CVWords,
    // Index of this chunk within the overall input; passed to the compression
    // function as its counter.
    chunk_counter: u64,
    // Buffer holding the current partially filled (or final) block.
    buf: BlockBytes,
    // Number of meaningful bytes in `buf` (<= BLOCK_LEN).
    buf_len: u8,
    // How many full blocks of this chunk have been compressed so far.
    blocks_compressed: u8,
    // Mode flags (e.g. KEYED_HASH / DERIVE_KEY_*) applied to every block.
    flags: u8,
}
69
70impl ConstChunkState {
71 const fn new(key: &CVWords, chunk_counter: u64, flags: u8) -> Self {
72 Self {
73 cv: *key,
74 chunk_counter,
75 buf: [0; BLOCK_LEN],
76 buf_len: 0,
77 blocks_compressed: 0,
78 flags,
79 }
80 }
81
82 const fn count(&self) -> usize {
83 BLOCK_LEN * self.blocks_compressed as usize + self.buf_len as usize
84 }
85
86 const fn fill_buf(&mut self, input: &mut &[u8]) {
87 let want = BLOCK_LEN - self.buf_len as usize;
88 let take = if want < input.len() {
89 want
90 } else {
91 input.len()
92 };
93 let output = self
94 .buf
95 .split_at_mut(self.buf_len as usize)
96 .1
97 .split_at_mut(take)
98 .0;
99 output.copy_from_slice(input.split_at(take).0);
100 self.buf_len += take as u8;
101 *input = input.split_at(take).1;
102 }
103
104 const fn start_flag(&self) -> u8 {
105 if self.blocks_compressed == 0 {
106 CHUNK_START
107 } else {
108 0
109 }
110 }
111
112 const fn update(&mut self, mut input: &[u8]) -> &mut Self {
115 if self.buf_len > 0 {
116 self.fill_buf(&mut input);
117 if !input.is_empty() {
118 debug_assert!(self.buf_len as usize == BLOCK_LEN);
119 let block_flags = self.flags | self.start_flag(); let block_words = words_from_le_bytes_64(&self.buf);
121 portable::compress_in_place(
122 &mut self.cv,
123 &block_words,
124 BLOCK_LEN as u32,
125 self.chunk_counter,
126 block_flags as u32,
127 );
128 self.buf_len = 0;
129 self.buf = [0; BLOCK_LEN];
130 self.blocks_compressed += 1;
131 }
132 }
133
134 while input.len() > BLOCK_LEN {
135 debug_assert!(self.buf_len == 0);
136 let block_flags = self.flags | self.start_flag(); let block = input
138 .first_chunk::<BLOCK_LEN>()
139 .expect("Interation only starts when there is at least `BLOCK_LEN` bytes; qed");
140 let block_words = words_from_le_bytes_64(block);
141 portable::compress_in_place(
142 &mut self.cv,
143 &block_words,
144 BLOCK_LEN as u32,
145 self.chunk_counter,
146 block_flags as u32,
147 );
148 self.blocks_compressed += 1;
149 input = input.split_at(BLOCK_LEN).1;
150 }
151
152 self.fill_buf(&mut input);
153 debug_assert!(input.is_empty());
154 debug_assert!(self.count() <= CHUNK_LEN);
155 self
156 }
157
158 const fn output(&self) -> ConstOutput {
159 let block_flags = self.flags | self.start_flag() | CHUNK_END;
160 ConstOutput {
161 input_chaining_value: self.cv,
162 block: self.buf,
163 block_len: self.buf_len,
164 counter: self.chunk_counter,
165 flags: block_flags,
166 }
167 }
168}
169
170const fn const_compress_chunks_parallel(
190 input: &[u8],
191 key: &CVWords,
192 chunk_counter: u64,
193 flags: u8,
194 out: &mut [u8],
195) -> usize {
196 debug_assert!(!input.is_empty(), "empty chunks below the root");
197 debug_assert!(input.len() <= MAX_SIMD_DEGREE * CHUNK_LEN);
198
199 let mut chunks = input;
200 let mut chunks_so_far = 0;
201 let mut chunks_array = [MaybeUninit::<&[u8; CHUNK_LEN]>::uninit(); MAX_SIMD_DEGREE];
202 while let Some(chunk) = chunks.first_chunk::<CHUNK_LEN>() {
203 chunks = chunks.split_at(CHUNK_LEN).1;
204 chunks_array[chunks_so_far].write(chunk);
205 chunks_so_far += 1;
206 }
207 portable::hash_many(
208 unsafe {
210 slice::from_raw_parts(
211 chunks_array.as_ptr().cast::<&[u8; CHUNK_LEN]>(),
212 chunks_so_far,
213 )
214 },
215 key,
216 chunk_counter,
217 IncrementCounter::Yes,
218 flags,
219 CHUNK_START,
220 CHUNK_END,
221 out,
222 );
223
224 if !chunks.is_empty() {
227 let counter = chunk_counter + chunks_so_far as u64;
228 let mut chunk_state = ConstChunkState::new(key, counter, flags);
229 chunk_state.update(chunks);
230 let out = out
231 .split_at_mut(chunks_so_far * OUT_LEN)
232 .1
233 .split_at_mut(OUT_LEN)
234 .0;
235 let chaining_value = chunk_state.output().chaining_value();
236 out.copy_from_slice(&chaining_value);
237 chunks_so_far + 1
238 } else {
239 chunks_so_far
240 }
241}
242
243const fn const_compress_parents_parallel(
249 child_chaining_values: &[u8],
250 key: &CVWords,
251 flags: u8,
252 out: &mut [u8],
253) -> usize {
254 debug_assert!(
255 child_chaining_values.len() % OUT_LEN == 0,
256 "wacky hash bytes"
257 );
258 let num_children = child_chaining_values.len() / OUT_LEN;
259 debug_assert!(num_children >= 2, "not enough children");
260 debug_assert!(num_children <= 2 * MAX_SIMD_DEGREE_OR_2, "too many");
261
262 let mut parents = child_chaining_values;
263 let mut parents_so_far = 0;
266 let mut parents_array = [MaybeUninit::<&BlockBytes>::uninit(); MAX_SIMD_DEGREE_OR_2];
267 while let Some(parent) = parents.first_chunk::<BLOCK_LEN>() {
268 parents = parents.split_at(BLOCK_LEN).1;
269 parents_array[parents_so_far].write(parent);
270 parents_so_far += 1;
271 }
272 portable::hash_many(
273 unsafe {
275 slice::from_raw_parts(parents_array.as_ptr().cast::<&BlockBytes>(), parents_so_far)
276 },
277 key,
278 0, IncrementCounter::No,
280 flags | PARENT,
281 0, 0, out,
284 );
285
286 if !parents.is_empty() {
288 let out = out
289 .split_at_mut(parents_so_far * OUT_LEN)
290 .1
291 .split_at_mut(OUT_LEN)
292 .0;
293 out.copy_from_slice(parents);
294 parents_so_far + 1
295 } else {
296 parents_so_far
297 }
298}
299
300const fn const_compress_subtree_wide(
318 input: &[u8],
319 key: &CVWords,
320 chunk_counter: u64,
321 flags: u8,
322 out: &mut [u8],
323) -> usize {
324 if input.len() <= CHUNK_LEN {
325 return const_compress_chunks_parallel(input, key, chunk_counter, flags, out);
326 }
327
328 let (left, right) = input.split_at(hazmat::left_subtree_len(input.len() as u64) as usize);
329 let right_chunk_counter = chunk_counter + (left.len() / CHUNK_LEN) as u64;
330
331 let mut cv_array = [0; 2 * MAX_SIMD_DEGREE_OR_2 * OUT_LEN];
335 let degree = if left.len() == CHUNK_LEN { 1 } else { 2 };
336 let (left_out, right_out) = cv_array.split_at_mut(degree * OUT_LEN);
337
338 let left_n = const_compress_subtree_wide(left, key, chunk_counter, flags, left_out);
340 let right_n = const_compress_subtree_wide(right, key, right_chunk_counter, flags, right_out);
341
342 debug_assert!(left_n == degree);
346 debug_assert!(right_n >= 1 && right_n <= left_n);
347 if left_n == 1 {
348 out.split_at_mut(2 * OUT_LEN)
349 .0
350 .copy_from_slice(cv_array.split_at(2 * OUT_LEN).0);
351 return 2;
352 }
353
354 let num_children = left_n + right_n;
356 const_compress_parents_parallel(cv_array.split_at(num_children * OUT_LEN).0, key, flags, out)
357}
358
359const fn const_compress_subtree_to_parent_node(
370 input: &[u8],
371 key: &CVWords,
372 chunk_counter: u64,
373 flags: u8,
374) -> BlockBytes {
375 debug_assert!(input.len() > CHUNK_LEN);
376 let mut cv_array = [0; MAX_SIMD_DEGREE_OR_2 * OUT_LEN];
377 let mut num_cvs = const_compress_subtree_wide(input, key, chunk_counter, flags, &mut cv_array);
378 debug_assert!(num_cvs >= 2);
379
380 let mut out_array = [0; MAX_SIMD_DEGREE_OR_2 * OUT_LEN / 2];
384 while num_cvs > 2 {
385 let cv_slice = cv_array.split_at(num_cvs * OUT_LEN).0;
386 num_cvs = const_compress_parents_parallel(cv_slice, key, flags, &mut out_array);
387 cv_array
388 .split_at_mut(num_cvs * OUT_LEN)
389 .0
390 .copy_from_slice(out_array.split_at(num_cvs * OUT_LEN).0);
391 }
392 *cv_array
393 .first_chunk::<BLOCK_LEN>()
394 .expect("`cv_array` is larger than `BLOCK_LEN`; qed")
395}
396
397const fn const_hash_all_at_once(input: &[u8], key: &CVWords, flags: u8) -> ConstOutput {
400 if input.len() <= CHUNK_LEN {
402 return ConstChunkState::new(key, 0, flags).update(input).output();
403 }
404
405 ConstOutput {
408 input_chaining_value: *key,
409 block: const_compress_subtree_to_parent_node(input, key, 0, flags),
410 block_len: BLOCK_LEN as u8,
411 counter: 0,
412 flags: flags | PARENT,
413 }
414}
415
416pub const fn const_hash(input: &[u8]) -> [u8; OUT_LEN] {
418 const_hash_all_at_once(input, IV, 0).root_hash()
419}
420
421pub const fn const_keyed_hash(key: &[u8; KEY_LEN], input: &[u8]) -> [u8; OUT_LEN] {
423 let key_words = words_from_le_bytes_32(key);
424 const_hash_all_at_once(input, &key_words, KEYED_HASH).root_hash()
425}
426
427pub const fn const_derive_key(context: &str, key_material: &[u8]) -> [u8; OUT_LEN] {
429 let context_key =
430 const_hash_all_at_once(context.as_bytes(), IV, DERIVE_KEY_CONTEXT).root_hash();
431 let context_key_words = words_from_le_bytes_32(&context_key);
432 const_hash_all_at_once(key_material, &context_key_words, DERIVE_KEY_MATERIAL).root_hash()
433}