ab_blake3/
const_fn.rs

//! `const fn` BLAKE3 functions.
//!
//! This module and its submodules are copied, with modifications, from the official [`blake3`]
//! crate and are expected to be removed once <https://github.com/BLAKE3-team/BLAKE3/pull/439> or
//! similar lands upstream.

mod hazmat;
#[cfg(test)]
mod tests;

use crate::platform::{
    le_bytes_from_words_32, words_from_le_bytes_32, words_from_le_bytes_64, MAX_SIMD_DEGREE,
    MAX_SIMD_DEGREE_OR_2,
};
use crate::portable::IncrementCounter;
use crate::{
    portable, BlockBytes, CVBytes, CVWords, BLOCK_LEN, CHUNK_END, CHUNK_LEN, CHUNK_START,
    DERIVE_KEY_CONTEXT, DERIVE_KEY_MATERIAL, IV, KEYED_HASH, KEY_LEN, OUT_LEN, PARENT, ROOT,
};
use core::mem::MaybeUninit;
use core::slice;
/// `Output` with `const fn` methods
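///
/// This represents a deferred final compression of one block: `chaining_value()` finalizes it as
/// an interior chaining value, while `root_hash()` finalizes it as the root by also setting the
/// `ROOT` flag.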
struct ConstOutput {
    input_chaining_value: CVWords,
    block: BlockBytes,
    block_len: u8,
    counter: u64,
    flags: u8,
}

impl ConstOutput {
    const fn chaining_value(&self) -> CVBytes {
        let mut cv = self.input_chaining_value;
        let block_words = words_from_le_bytes_64(&self.block);
        portable::compress_in_place(
            &mut cv,
            &block_words,
            self.block_len as u32,
            self.counter,
            self.flags as u32,
        );
        *le_bytes_from_words_32(&cv)
    }

    const fn root_hash(&self) -> [u8; OUT_LEN] {
        debug_assert!(self.counter == 0);
        let mut cv = self.input_chaining_value;
        let block_words = words_from_le_bytes_64(&self.block);
        portable::compress_in_place(
            &mut cv,
            &block_words,
            self.block_len as u32,
            0,
            (self.flags | ROOT) as u32,
        );
        *le_bytes_from_words_32(&cv)
    }
}

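/// `ChunkState` with `const fn` methods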
struct ConstChunkState {
    cv: CVWords,
    chunk_counter: u64,
    buf: BlockBytes,
    buf_len: u8,
    blocks_compressed: u8,
    flags: u8,
}

impl ConstChunkState {
    const fn new(key: &CVWords, chunk_counter: u64, flags: u8) -> Self {
        Self {
            cv: *key,
            chunk_counter,
            buf: [0; BLOCK_LEN],
            buf_len: 0,
            blocks_compressed: 0,
            flags,
        }
    }

    const fn count(&self) -> usize {
        BLOCK_LEN * self.blocks_compressed as usize + self.buf_len as usize
    }

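    // Copies as many bytes as fit (at most one full block) from the front of
    // `input` into `self.buf`, advancing `*input` past the bytes consumed.
    // `split_at`/`split_at_mut` stand in for range indexing, which is not
    // available in `const fn`.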
    const fn fill_buf(&mut self, input: &mut &[u8]) {
        let want = BLOCK_LEN - self.buf_len as usize;
        let take = if want < input.len() {
            want
        } else {
            input.len()
        };
        let output = self
            .buf
            .split_at_mut(self.buf_len as usize)
            .1
            .split_at_mut(take)
            .0;
        output.copy_from_slice(input.split_at(take).0);
        self.buf_len += take as u8;
        *input = input.split_at(take).1;
    }

    const fn start_flag(&self) -> u8 {
        if self.blocks_compressed == 0 {
            CHUNK_START
        } else {
            0
        }
    }

    // Try to avoid buffering as much as possible, by compressing directly from
    // the input slice when full blocks are available.
    const fn update(&mut self, mut input: &[u8]) -> &mut Self {
        if self.buf_len > 0 {
            self.fill_buf(&mut input);
            if !input.is_empty() {
                debug_assert!(self.buf_len as usize == BLOCK_LEN);
                let block_flags = self.flags | self.start_flag(); // borrowck
                let block_words = words_from_le_bytes_64(&self.buf);
                portable::compress_in_place(
                    &mut self.cv,
                    &block_words,
                    BLOCK_LEN as u32,
                    self.chunk_counter,
                    block_flags as u32,
                );
                self.buf_len = 0;
                self.buf = [0; BLOCK_LEN];
                self.blocks_compressed += 1;
            }
        }

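        // Compress full blocks directly from `input`, but note the strict `>`:
        // the loop always leaves at least one byte behind, so the chunk's final
        // block stays buffered for `output()` to tag with CHUNK_END.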
        while input.len() > BLOCK_LEN {
            debug_assert!(self.buf_len == 0);
            let block_flags = self.flags | self.start_flag(); // borrowck
            let block = input
                .first_chunk::<BLOCK_LEN>()
                .expect("Iteration only starts when there are at least `BLOCK_LEN` bytes; qed");
            let block_words = words_from_le_bytes_64(block);
            portable::compress_in_place(
                &mut self.cv,
                &block_words,
                BLOCK_LEN as u32,
                self.chunk_counter,
                block_flags as u32,
            );
            self.blocks_compressed += 1;
            input = input.split_at(BLOCK_LEN).1;
        }

        self.fill_buf(&mut input);
        debug_assert!(input.is_empty());
        debug_assert!(self.count() <= CHUNK_LEN);
        self
    }

    const fn output(&self) -> ConstOutput {
        let block_flags = self.flags | self.start_flag() | CHUNK_END;
        ConstOutput {
            input_chaining_value: self.cv,
            block: self.buf,
            block_len: self.buf_len,
            counter: self.chunk_counter,
            flags: block_flags,
        }
    }
}

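// A chunk's lifecycle as driven by the code below: `ConstChunkState::new()`,
// then `update()` calls totaling at most CHUNK_LEN bytes, then `output()`,
// whose result is finalized with either `chaining_value()` (interior chunk)
// or `root_hash()` (single-chunk input).
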
// IMPLEMENTATION NOTE
// ===================
// The recursive function compress_subtree_wide(), implemented below, is the
// basis of high-performance BLAKE3. We use it both for all-at-once hashing,
// and for the incremental input with Hasher (though we have to be careful with
// subtree boundaries in the incremental case). compress_subtree_wide() applies
// several optimizations at the same time:
// - Multithreading with Rayon.
// - Parallel chunk hashing with SIMD.
// - Parallel parent hashing with SIMD. Note that while SIMD chunk hashing
//   maxes out at MAX_SIMD_DEGREE*CHUNK_LEN, parallel parent hashing continues
//   to benefit from larger inputs, because more levels of the tree can use
//   full-width SIMD vectors for parent hashing. Without parallel parent
//   hashing, we lose about 10% of overall throughput on AVX2 and AVX-512.
//
// This `const fn` port runs everything on the single-threaded portable
// implementation, but it preserves the same tree structure and flag logic as
// upstream.
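//
// As a concrete sketch (a hypothetical 4-chunk input; chunks are leaves and
// each parent hashes the concatenated chaining values of its two children):
//
//             root parent
//            /           \
//       parent           parent
//       /    \           /    \
//   chunk0  chunk1   chunk2  chunk3
//
// compress_subtree_wide() hashes the row of chunks with hash_many(), then
// condenses each row of parents the same way on the way back up the tree.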

// Use SIMD parallelism to hash up to MAX_SIMD_DEGREE chunks at the same time
// on a single thread. Write out the chunk chaining values and return the
// number of chunks hashed. These chunks are never the root and never empty;
// those cases use a different codepath.
const fn const_compress_chunks_parallel(
    input: &[u8],
    key: &CVWords,
    chunk_counter: u64,
    flags: u8,
    out: &mut [u8],
) -> usize {
    debug_assert!(!input.is_empty(), "empty chunks below the root");
    debug_assert!(input.len() <= MAX_SIMD_DEGREE * CHUNK_LEN);

    let mut chunks = input;
    let mut chunks_so_far = 0;
    let mut chunks_array = [MaybeUninit::<&[u8; CHUNK_LEN]>::uninit(); MAX_SIMD_DEGREE];
    while let Some(chunk) = chunks.first_chunk::<CHUNK_LEN>() {
        chunks = chunks.split_at(CHUNK_LEN).1;
        chunks_array[chunks_so_far].write(chunk);
        chunks_so_far += 1;
    }
    portable::hash_many(
        // SAFETY: Exactly `chunks_so_far` elements of `chunks_array` were initialized above
        unsafe {
            slice::from_raw_parts(
                chunks_array.as_ptr().cast::<&[u8; CHUNK_LEN]>(),
                chunks_so_far,
            )
        },
        key,
        chunk_counter,
        IncrementCounter::Yes,
        flags,
        CHUNK_START,
        CHUNK_END,
        out,
    );

    // Hash the remaining partial chunk, if there is one. Note that the empty
    // chunk (meaning the empty message) is a different codepath.
    if !chunks.is_empty() {
        let counter = chunk_counter + chunks_so_far as u64;
        let mut chunk_state = ConstChunkState::new(key, counter, flags);
        chunk_state.update(chunks);
        let out = out
            .split_at_mut(chunks_so_far * OUT_LEN)
            .1
            .split_at_mut(OUT_LEN)
            .0;
        let chaining_value = chunk_state.output().chaining_value();
        out.copy_from_slice(&chaining_value);
        chunks_so_far + 1
    } else {
        chunks_so_far
    }
}

// Use SIMD parallelism to hash up to MAX_SIMD_DEGREE parents at the same time
// on a single thread. Write out the parent chaining values and return the
// number of parents hashed. (If there's an odd input chaining value left over,
// return it as an additional output.) These parents are never the root and
// never empty; those cases use a different codepath.
const fn const_compress_parents_parallel(
    child_chaining_values: &[u8],
    key: &CVWords,
    flags: u8,
    out: &mut [u8],
) -> usize {
    debug_assert!(
        child_chaining_values.len() % OUT_LEN == 0,
        "wacky hash bytes"
    );
    let num_children = child_chaining_values.len() / OUT_LEN;
    debug_assert!(num_children >= 2, "not enough children");
    debug_assert!(num_children <= 2 * MAX_SIMD_DEGREE_OR_2, "too many");

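    // Each parent block is the concatenation of two child chaining values
    // (BLOCK_LEN == 2 * OUT_LEN), so stepping through the input in BLOCK_LEN
    // increments pairs the children up.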
    let mut parents = child_chaining_values;
    // Use MAX_SIMD_DEGREE_OR_2 rather than MAX_SIMD_DEGREE here, because of
    // the requirements of compress_subtree_wide().
    let mut parents_so_far = 0;
    let mut parents_array = [MaybeUninit::<&BlockBytes>::uninit(); MAX_SIMD_DEGREE_OR_2];
    while let Some(parent) = parents.first_chunk::<BLOCK_LEN>() {
        parents = parents.split_at(BLOCK_LEN).1;
        parents_array[parents_so_far].write(parent);
        parents_so_far += 1;
    }
    portable::hash_many(
        // SAFETY: Exactly `parents_so_far` elements of `parents_array` were initialized above
        unsafe {
            slice::from_raw_parts(parents_array.as_ptr().cast::<&BlockBytes>(), parents_so_far)
        },
        key,
        0, // Parents always use counter 0.
        IncrementCounter::No,
        flags | PARENT,
        0, // Parents have no start flags.
        0, // Parents have no end flags.
        out,
    );

    // If there's an odd child left over, it becomes an output.
    if !parents.is_empty() {
        let out = out
            .split_at_mut(parents_so_far * OUT_LEN)
            .1
            .split_at_mut(OUT_LEN)
            .0;
        out.copy_from_slice(parents);
        parents_so_far + 1
    } else {
        parents_so_far
    }
}

// The wide helper function returns (writes out) an array of chaining values
// and returns the length of that array. The number of chaining values returned
// is the dynamically detected SIMD degree, at most MAX_SIMD_DEGREE. Or fewer,
// if the input is shorter than that many chunks. The reason for maintaining a
// wide array of chaining values going back up the tree, is to allow the
// implementation to hash as many parents in parallel as possible.
//
// As a special case when the SIMD degree is 1, this function will still return
// at least 2 outputs. This guarantees that this function doesn't perform the
// root compression. (If it did, it would use the wrong flags, and also we
// wouldn't be able to implement extendable output.) Note that this function is
// not used when the whole input is only 1 chunk long; that's a different
// codepath.
//
// Why not just have the caller split the input on the first update(), instead
// of implementing this special rule? Because we don't want to limit SIMD or
// multithreading parallelism for that update().
const fn const_compress_subtree_wide(
    input: &[u8],
    key: &CVWords,
    chunk_counter: u64,
    flags: u8,
    out: &mut [u8],
) -> usize {
    if input.len() <= CHUNK_LEN {
        return const_compress_chunks_parallel(input, key, chunk_counter, flags, out);
    }

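    // `left_subtree_len()` gives the largest power-of-two number of chunks
    // that still leaves at least one byte for the right subtree; e.g. a
    // 5-chunk input splits 4 chunks left and 1 chunk right, and an 8-chunk
    // input splits 4 and 4.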
    let (left, right) = input.split_at(hazmat::left_subtree_len(input.len() as u64) as usize);
    let right_chunk_counter = chunk_counter + (left.len() / CHUNK_LEN) as u64;

    // Make space for the child outputs. Here we use MAX_SIMD_DEGREE_OR_2 to
    // account for the special case of returning 2 outputs when the SIMD degree
    // is 1.
    let mut cv_array = [0; 2 * MAX_SIMD_DEGREE_OR_2 * OUT_LEN];
    let degree = if left.len() == CHUNK_LEN { 1 } else { 2 };
    let (left_out, right_out) = cv_array.split_at_mut(degree * OUT_LEN);

    // Recurse!
    let left_n = const_compress_subtree_wide(left, key, chunk_counter, flags, left_out);
    let right_n = const_compress_subtree_wide(right, key, right_chunk_counter, flags, right_out);

    // The special case again. If simd_degree=1, then we'll have left_n=1 and
    // right_n=1. Rather than compressing them into a single output, return
    // them directly, to make sure we always have at least two outputs.
    debug_assert!(left_n == degree);
    debug_assert!(right_n >= 1 && right_n <= left_n);
    if left_n == 1 {
        out.split_at_mut(2 * OUT_LEN)
            .0
            .copy_from_slice(cv_array.split_at(2 * OUT_LEN).0);
        return 2;
    }

    // Otherwise, do one layer of parent node compression.
    let num_children = left_n + right_n;
    const_compress_parents_parallel(cv_array.split_at(num_children * OUT_LEN).0, key, flags, out)
}

// Hash a subtree with compress_subtree_wide(), and then condense the resulting
// list of chaining values down to a single parent node. Don't compress that
// last parent node, however. Instead, return its message bytes (the
// concatenated chaining values of its children). This is necessary when the
// first call to update() supplies a complete subtree, because the topmost
// parent node of that subtree could end up being the root. It's also necessary
// for extended output in the general case.
//
// As with compress_subtree_wide(), this function is not used on inputs of 1
// chunk or less. That's a different codepath.
const fn const_compress_subtree_to_parent_node(
    input: &[u8],
    key: &CVWords,
    chunk_counter: u64,
    flags: u8,
) -> BlockBytes {
    debug_assert!(input.len() > CHUNK_LEN);
    let mut cv_array = [0; MAX_SIMD_DEGREE_OR_2 * OUT_LEN];
    let mut num_cvs = const_compress_subtree_wide(input, key, chunk_counter, flags, &mut cv_array);
    debug_assert!(num_cvs >= 2);

    // If MAX_SIMD_DEGREE is greater than 2 and there's enough input,
    // compress_subtree_wide() returns more than 2 chaining values. Condense
    // them into 2 by forming parent nodes repeatedly.
    let mut out_array = [0; MAX_SIMD_DEGREE_OR_2 * OUT_LEN / 2];
    while num_cvs > 2 {
        let cv_slice = cv_array.split_at(num_cvs * OUT_LEN).0;
        num_cvs = const_compress_parents_parallel(cv_slice, key, flags, &mut out_array);
        cv_array
            .split_at_mut(num_cvs * OUT_LEN)
            .0
            .copy_from_slice(out_array.split_at(num_cvs * OUT_LEN).0);
    }
    *cv_array
        .first_chunk::<BLOCK_LEN>()
        .expect("`cv_array` is at least `BLOCK_LEN` bytes; qed")
}

// Hash a complete input all at once. Unlike compress_subtree_wide() and
// compress_subtree_to_parent_node(), this function handles the 1 chunk case.
const fn const_hash_all_at_once(input: &[u8], key: &CVWords, flags: u8) -> ConstOutput {
    // If the whole subtree is one chunk, hash it directly with a ChunkState.
    if input.len() <= CHUNK_LEN {
        return ConstChunkState::new(key, 0, flags).update(input).output();
    }

    // Otherwise construct an Output object from the parent node returned by
    // compress_subtree_to_parent_node().
    ConstOutput {
        input_chaining_value: *key,
        block: const_compress_subtree_to_parent_node(input, key, 0, flags),
        block_len: BLOCK_LEN as u8,
        counter: 0,
        flags: flags | PARENT,
    }
}

/// Hashing function like [`blake3::hash()`], but `const fn`
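///
/// A minimal sketch of intended use (`ignore`d since the public re-export
/// path lives outside this file):
///
/// ```ignore
/// // The whole digest is computable in a `const` context. OUT_LEN == 32.
/// const DIGEST: [u8; 32] = const_hash(b"example input");
/// ```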
pub const fn const_hash(input: &[u8]) -> [u8; OUT_LEN] {
    const_hash_all_at_once(input, IV, 0).root_hash()
}

/// The keyed hash function like [`blake3::keyed_hash()`], but `const fn`
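///
/// A minimal sketch of intended use (`ignore`d since the public re-export
/// path lives outside this file):
///
/// ```ignore
/// const KEY: [u8; 32] = [0x42; 32]; // KEY_LEN == 32
/// const MAC: [u8; 32] = const_keyed_hash(&KEY, b"message");
/// ```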
pub const fn const_keyed_hash(key: &[u8; KEY_LEN], input: &[u8]) -> [u8; OUT_LEN] {
    let key_words = words_from_le_bytes_32(key);
    const_hash_all_at_once(input, &key_words, KEYED_HASH).root_hash()
}

/// The key derivation function like [`blake3::derive_key()`], but `const fn`
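///
/// A minimal sketch of intended use (`ignore`d since the public re-export
/// path lives outside this file; the hypothetical context string follows the
/// upstream "[application] [commit timestamp] [purpose]" convention):
///
/// ```ignore
/// const KEY: [u8; 32] =
///     const_derive_key("example.com 2024-11-01 session tokens v1", b"key material");
/// ```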
pub const fn const_derive_key(context: &str, key_material: &[u8]) -> [u8; OUT_LEN] {
    let context_key =
        const_hash_all_at_once(context.as_bytes(), IV, DERIVE_KEY_CONTEXT).root_hash();
    let context_key_words = words_from_le_bytes_32(&context_key);
    const_hash_all_at_once(key_material, &context_key_words, DERIVE_KEY_MATERIAL).root_hash()
}