mod hazmat;

#[cfg(test)]
mod tests;

use crate::platform::{
    MAX_SIMD_DEGREE, MAX_SIMD_DEGREE_OR_2, le_bytes_from_words_32, words_from_le_bytes_32,
    words_from_le_bytes_64,
};
use crate::{
    BLOCK_LEN, BlockBytes, CHUNK_END, CHUNK_LEN, CHUNK_START, CVBytes, CVWords, DERIVE_KEY_CONTEXT,
    DERIVE_KEY_MATERIAL, IV, KEY_LEN, KEYED_HASH, OUT_LEN, PARENT, ROOT, portable,
};
use crate::IncrementCounter;
use core::mem::MaybeUninit;
use core::slice;

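// The `const` counterpart of the crate's internal `Output` type: a captured
// final block plus the state needed to compress it into either a non-root
// chaining value or the root hash.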
struct ConstOutput {
    input_chaining_value: CVWords,
    block: BlockBytes,
    block_len: u8,
    counter: u64,
    flags: u8,
}

impl ConstOutput {
    const fn chaining_value(&self) -> CVBytes {
        let mut cv = self.input_chaining_value;
        let block_words = words_from_le_bytes_64(&self.block);
        portable::compress_in_place(
            &mut cv,
            &block_words,
            self.block_len as u32,
            self.counter,
            self.flags as u32,
        );
        *le_bytes_from_words_32(&cv)
    }

    const fn root_hash(&self) -> [u8; OUT_LEN] {
        debug_assert!(self.counter == 0);
        let mut cv = self.input_chaining_value;
        let block_words = words_from_le_bytes_64(&self.block);
        portable::compress_in_place(
            &mut cv,
            &block_words,
            self.block_len as u32,
            0,
            (self.flags | ROOT) as u32,
        );
        *le_bytes_from_words_32(&cv)
    }
}

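// Incremental hashing state for a single chunk (at most CHUNK_LEN bytes),
// mirroring the crate's internal `ChunkState` but restricted to operations
// that are callable in `const fn`.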
struct ConstChunkState {
    cv: CVWords,
    chunk_counter: u64,
    buf: BlockBytes,
    buf_len: u8,
    blocks_compressed: u8,
    flags: u8,
}

impl ConstChunkState {
    const fn new(key: &CVWords, chunk_counter: u64, flags: u8) -> Self {
        Self {
            cv: *key,
            chunk_counter,
            buf: [0; BLOCK_LEN],
            buf_len: 0,
            blocks_compressed: 0,
            flags,
        }
    }

    const fn count(&self) -> usize {
        BLOCK_LEN * self.blocks_compressed as usize + self.buf_len as usize
    }

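    // Copy as many bytes as fit from the front of `input` into `buf` and
    // advance `input` past them. `split_at`/`split_at_mut` stand in for range
    // indexing, which is not callable in `const fn`.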
    const fn fill_buf(&mut self, input: &mut &[u8]) {
        let want = BLOCK_LEN - self.buf_len as usize;
        let take = if want < input.len() {
            want
        } else {
            input.len()
        };
        let output = self
            .buf
            .split_at_mut(self.buf_len as usize)
            .1
            .split_at_mut(take)
            .0;
        output.copy_from_slice(input.split_at(take).0);
        self.buf_len += take as u8;
        *input = input.split_at(take).1;
    }

    const fn start_flag(&self) -> u8 {
        if self.blocks_compressed == 0 {
            CHUNK_START
        } else {
            0
        }
    }

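    // Absorb `input`: if the internal buffer is partially full, top it up
    // and, if more input remains, compress it. Then compress full blocks
    // straight from `input`, always leaving the final (possibly partial)
    // block in the buffer so that `output` can apply the CHUNK_END flag.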
    const fn update(&mut self, mut input: &[u8]) -> &mut Self {
        if self.buf_len > 0 {
            self.fill_buf(&mut input);
            if !input.is_empty() {
                debug_assert!(self.buf_len as usize == BLOCK_LEN);
                let block_flags = self.flags | self.start_flag();
                let block_words = words_from_le_bytes_64(&self.buf);
                portable::compress_in_place(
                    &mut self.cv,
                    &block_words,
                    BLOCK_LEN as u32,
                    self.chunk_counter,
                    block_flags as u32,
                );
                self.buf_len = 0;
                self.buf = [0; BLOCK_LEN];
                self.blocks_compressed += 1;
            }
        }

        while input.len() > BLOCK_LEN {
            debug_assert!(self.buf_len == 0);
            let block_flags = self.flags | self.start_flag();
            let block = input
                .first_chunk::<BLOCK_LEN>()
                .expect("Iteration only starts when there are at least `BLOCK_LEN` bytes; qed");
            let block_words = words_from_le_bytes_64(block);
            portable::compress_in_place(
                &mut self.cv,
                &block_words,
                BLOCK_LEN as u32,
                self.chunk_counter,
                block_flags as u32,
            );
            self.blocks_compressed += 1;
            input = input.split_at(BLOCK_LEN).1;
        }

        self.fill_buf(&mut input);
        debug_assert!(input.is_empty());
        debug_assert!(self.count() <= CHUNK_LEN);
        self
    }

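    // Capture the final block of this chunk as a ConstOutput, applying
    // CHUNK_END (and CHUNK_START for a single-block chunk) without
    // compressing it yet.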
    const fn output(&self) -> ConstOutput {
        let block_flags = self.flags | self.start_flag() | CHUNK_END;
        ConstOutput {
            input_chaining_value: self.cv,
            block: self.buf,
            block_len: self.buf_len,
            counter: self.chunk_counter,
            flags: block_flags,
        }
    }
}

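// Hash up to MAX_SIMD_DEGREE whole chunks with `hash_many`, writing one
// chaining value per chunk to `out` and returning the number written. A
// trailing partial chunk, if any, goes through a ConstChunkState. These
// chunks are never the root and never empty; those cases take a different
// code path.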
const fn const_compress_chunks_parallel(
    input: &[u8],
    key: &CVWords,
    chunk_counter: u64,
    flags: u8,
    out: &mut [u8],
) -> usize {
    debug_assert!(!input.is_empty(), "empty chunks below the root");
    debug_assert!(input.len() <= MAX_SIMD_DEGREE * CHUNK_LEN);

    let mut chunks = input;
    let mut chunks_so_far = 0;
    let mut chunks_array = [MaybeUninit::<&[u8; CHUNK_LEN]>::uninit(); MAX_SIMD_DEGREE];
    while let Some(chunk) = chunks.first_chunk::<CHUNK_LEN>() {
        chunks = chunks.split_at(CHUNK_LEN).1;
        chunks_array[chunks_so_far].write(chunk);
        chunks_so_far += 1;
    }
    portable::hash_many(
        // SAFETY: the first `chunks_so_far` elements of `chunks_array` were
        // initialized by the loop above.
        unsafe {
            slice::from_raw_parts(
                chunks_array.as_ptr().cast::<&[u8; CHUNK_LEN]>(),
                chunks_so_far,
            )
        },
        key,
        chunk_counter,
        IncrementCounter::Yes,
        flags,
        CHUNK_START,
        CHUNK_END,
        out,
    );

    // Hash the remaining partial chunk, if there is one. Note that the empty
    // chunk (meaning the empty message) is a different code path.
    if !chunks.is_empty() {
        let counter = chunk_counter + chunks_so_far as u64;
        let mut chunk_state = ConstChunkState::new(key, counter, flags);
        chunk_state.update(chunks);
        let out = out
            .split_at_mut(chunks_so_far * OUT_LEN)
            .1
            .split_at_mut(OUT_LEN)
            .0;
        let chaining_value = chunk_state.output().chaining_value();
        out.copy_from_slice(&chaining_value);
        chunks_so_far + 1
    } else {
        chunks_so_far
    }
}

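// Compress pairs of child chaining values into parent chaining values,
// writing them to `out` and returning the number written. An odd child left
// over at the end is copied through unchanged, to be joined at a higher level
// of the tree. These parents are never the root; the root is compressed
// separately with the ROOT flag.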
const fn const_compress_parents_parallel(
    child_chaining_values: &[u8],
    key: &CVWords,
    flags: u8,
    out: &mut [u8],
) -> usize {
    debug_assert!(
        child_chaining_values.len().is_multiple_of(OUT_LEN),
        "wacky hash bytes"
    );
    let num_children = child_chaining_values.len() / OUT_LEN;
    debug_assert!(num_children >= 2, "not enough children");
    debug_assert!(num_children <= 2 * MAX_SIMD_DEGREE_OR_2, "too many");

    let mut parents = child_chaining_values;
    let mut parents_so_far = 0;
    let mut parents_array = [MaybeUninit::<&BlockBytes>::uninit(); MAX_SIMD_DEGREE_OR_2];
    while let Some(parent) = parents.first_chunk::<BLOCK_LEN>() {
        parents = parents.split_at(BLOCK_LEN).1;
        parents_array[parents_so_far].write(parent);
        parents_so_far += 1;
    }
    portable::hash_many(
        // SAFETY: the first `parents_so_far` elements of `parents_array` were
        // initialized by the loop above.
        unsafe {
            slice::from_raw_parts(parents_array.as_ptr().cast::<&BlockBytes>(), parents_so_far)
        },
        key,
        0, // Parents always use counter 0.
        IncrementCounter::No,
        flags | PARENT,
        0, // Parents have no start flags.
        0, // Parents have no end flags.
        out,
    );

    // If there's an odd child left over, copy its chaining value through to
    // `out` unchanged.
    if !parents.is_empty() {
        let out = out
            .split_at_mut(parents_so_far * OUT_LEN)
            .1
            .split_at_mut(OUT_LEN)
            .0;
        out.copy_from_slice(parents);
        parents_so_far + 1
    } else {
        parents_so_far
    }
}

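// The wide helper: hash a subtree and write its chaining values to `out`,
// returning how many were written. Inputs of at most one chunk bottom out in
// const_compress_chunks_parallel; larger inputs are split so the left subtree
// holds the largest power-of-two number of whole chunks that leaves at least
// one byte for the right, and both halves recurse.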
const fn const_compress_subtree_wide(
    input: &[u8],
    key: &CVWords,
    chunk_counter: u64,
    flags: u8,
    out: &mut [u8],
) -> usize {
    if input.len() <= CHUNK_LEN {
        return const_compress_chunks_parallel(input, key, chunk_counter, flags, out);
    }

    let (left, right) = input.split_at(hazmat::left_subtree_len(input.len() as u64) as usize);
    let right_chunk_counter = chunk_counter + (left.len() / CHUNK_LEN) as u64;

    // Make space for the child outputs, using MAX_SIMD_DEGREE_OR_2 to account
    // for the special case of returning 2 outputs when the degree is 1.
    let mut cv_array = [0; 2 * MAX_SIMD_DEGREE_OR_2 * OUT_LEN];
    let degree = if left.len() == CHUNK_LEN { 1 } else { 2 };
    let (left_out, right_out) = cv_array.split_at_mut(degree * OUT_LEN);

    let left_n = const_compress_subtree_wide(left, key, chunk_counter, flags, left_out);
    let right_n = const_compress_subtree_wide(right, key, right_chunk_counter, flags, right_out);

    // If the degree is 1, return the two leaf chaining values directly rather
    // than compressing them, so there are always at least two outputs.
    debug_assert!(left_n == degree);
    debug_assert!(right_n >= 1 && right_n <= left_n);
    if left_n == 1 {
        out.split_at_mut(2 * OUT_LEN)
            .0
            .copy_from_slice(cv_array.split_at(2 * OUT_LEN).0);
        return 2;
    }

    // Otherwise, do one layer of parent node compression.
    let num_children = left_n + right_n;
    const_compress_parents_parallel(cv_array.split_at(num_children * OUT_LEN).0, key, flags, out)
}

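// Hash a subtree larger than one chunk and condense its chaining values down
// to exactly two, returned concatenated as a parent block. The caller turns
// that block into the root hash or a parent chaining value.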
const fn const_compress_subtree_to_parent_node(
    input: &[u8],
    key: &CVWords,
    chunk_counter: u64,
    flags: u8,
) -> BlockBytes {
    debug_assert!(input.len() > CHUNK_LEN);
    let mut cv_array = [0; MAX_SIMD_DEGREE_OR_2 * OUT_LEN];
    let mut num_cvs = const_compress_subtree_wide(input, key, chunk_counter, flags, &mut cv_array);
    debug_assert!(num_cvs >= 2);

    // If const_compress_subtree_wide returned more than 2 chaining values,
    // condense them into 2 by forming parent nodes repeatedly.
    let mut out_array = [0; MAX_SIMD_DEGREE_OR_2 * OUT_LEN / 2];
    while num_cvs > 2 {
        let cv_slice = cv_array.split_at(num_cvs * OUT_LEN).0;
        num_cvs = const_compress_parents_parallel(cv_slice, key, flags, &mut out_array);
        cv_array
            .split_at_mut(num_cvs * OUT_LEN)
            .0
            .copy_from_slice(out_array.split_at(num_cvs * OUT_LEN).0);
    }
    *cv_array
        .first_chunk::<BLOCK_LEN>()
        .expect("`cv_array` is at least `BLOCK_LEN` bytes long; qed")
}

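// Hash an entire input in one call, returning a ConstOutput from which the
// caller takes the root hash.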
const fn const_hash_all_at_once(input: &[u8], key: &CVWords, flags: u8) -> ConstOutput {
    // If the whole input fits in one chunk, hash it directly with a
    // ConstChunkState.
    if input.len() <= CHUNK_LEN {
        return ConstChunkState::new(key, 0, flags).update(input).output();
    }

    // Otherwise construct a ConstOutput from the parent node returned by
    // const_compress_subtree_to_parent_node.
    ConstOutput {
        input_chaining_value: *key,
        block: const_compress_subtree_to_parent_node(input, key, 0, flags),
        block_len: BLOCK_LEN as u8,
        counter: 0,
        flags: flags | PARENT,
    }
}

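/// The `const` counterpart of the crate's `hash` function: computes the
/// default-mode BLAKE3 digest of `input` using only the portable,
/// `const`-evaluable compression code.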
pub const fn const_hash(input: &[u8]) -> [u8; OUT_LEN] {
    const_hash_all_at_once(input, IV, 0).root_hash()
}

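/// The `const` counterpart of `keyed_hash`: computes the keyed-mode BLAKE3
/// digest of `input` under the 32-byte `key`.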
pub const fn const_keyed_hash(key: &[u8; KEY_LEN], input: &[u8]) -> [u8; OUT_LEN] {
    let key_words = words_from_le_bytes_32(key);
    const_hash_all_at_once(input, &key_words, KEYED_HASH).root_hash()
}

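/// The `const` counterpart of `derive_key`: derives a key from `key_material`
/// for the given `context`. As with `derive_key`, the context string should
/// be hardcoded, globally unique, and application-specific.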
pub const fn const_derive_key(context: &str, key_material: &[u8]) -> [u8; OUT_LEN] {
    let context_key =
        const_hash_all_at_once(context.as_bytes(), IV, DERIVE_KEY_CONTEXT).root_hash();
    let context_key_words = words_from_le_bytes_32(&context_key);
    const_hash_all_at_once(key_material, &context_key_words, DERIVE_KEY_MATERIAL).root_hash()
}
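
// Usage sketch (an editorial example, assuming these functions are reachable
// from the crate root; `OUT_LEN` and `KEY_LEN` are both 32):
//
//     const DIGEST: [u8; OUT_LEN] = const_hash(b"hello world");
//     const MAC: [u8; OUT_LEN] = const_keyed_hash(&[0x42; KEY_LEN], b"message");
//     const SESSION_KEY: [u8; OUT_LEN] =
//         const_derive_key("example.com 2024 session keys", b"key material");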