ab_blake3/const_fn.rs

//! `const fn` BLAKE3 functions.
//!
//! This module and submodules are copied with modifications from the official [`blake3`] crate and
//! are expected to be removed once <https://github.com/BLAKE3-team/BLAKE3/pull/439> or similar
//! lands upstream.

mod hazmat;
#[cfg(test)]
mod tests;

use crate::platform::{
    MAX_SIMD_DEGREE, MAX_SIMD_DEGREE_OR_2, le_bytes_from_words_32, words_from_le_bytes_32,
    words_from_le_bytes_64,
};
use crate::{
    BLOCK_LEN, BlockBytes, CHUNK_END, CHUNK_LEN, CHUNK_START, CVBytes, CVWords, DERIVE_KEY_CONTEXT,
    DERIVE_KEY_MATERIAL, IV, KEY_LEN, KEYED_HASH, OUT_LEN, PARENT, ROOT, portable,
};
use blake3::IncrementCounter;
use core::mem::MaybeUninit;
use core::slice;
/// `Output` with `const fn` methods
struct ConstOutput {
    input_chaining_value: CVWords,
    block: BlockBytes,
    block_len: u8,
    counter: u64,
    flags: u8,
}

impl ConstOutput {
    const fn chaining_value(&self) -> CVBytes {
        let mut cv = self.input_chaining_value;
        let block_words = words_from_le_bytes_64(&self.block);
        portable::compress_in_place(
            &mut cv,
            &block_words,
            self.block_len as u32,
            self.counter,
            self.flags as u32,
        );
        *le_bytes_from_words_32(&cv)
    }

    const fn root_hash(&self) -> [u8; OUT_LEN] {
        debug_assert!(self.counter == 0);
        let mut cv = self.input_chaining_value;
        let block_words = words_from_le_bytes_64(&self.block);
        portable::compress_in_place(
            &mut cv,
            &block_words,
            self.block_len as u32,
            0,
            (self.flags | ROOT) as u32,
        );
        *le_bytes_from_words_32(&cv)
    }
}

struct ConstChunkState {
    cv: CVWords,
    chunk_counter: u64,
    buf: BlockBytes,
    buf_len: u8,
    blocks_compressed: u8,
    flags: u8,
}

impl ConstChunkState {
    const fn new(key: &CVWords, chunk_counter: u64, flags: u8) -> Self {
        Self {
            cv: *key,
            chunk_counter,
            buf: [0; BLOCK_LEN],
            buf_len: 0,
            blocks_compressed: 0,
            flags,
        }
    }

    const fn count(&self) -> usize {
        BLOCK_LEN * self.blocks_compressed as usize + self.buf_len as usize
    }
    // Copy as many input bytes as still fit into the block buffer and advance
    // `input` past the bytes consumed.
    const fn fill_buf(&mut self, input: &mut &[u8]) {
        let want = BLOCK_LEN - self.buf_len as usize;
        let take = if want < input.len() {
            want
        } else {
            input.len()
        };
        // `split_at_mut()`/`split_at()` instead of range indexing, which is not
        // callable in `const fn`.
        let output = self
            .buf
            .split_at_mut(self.buf_len as usize)
            .1
            .split_at_mut(take)
            .0;
        output.copy_from_slice(input.split_at(take).0);
        self.buf_len += take as u8;
        *input = input.split_at(take).1;
    }

    const fn start_flag(&self) -> u8 {
        if self.blocks_compressed == 0 {
            CHUNK_START
        } else {
            0
        }
    }

    // Try to avoid buffering as much as possible by compressing directly from
    // the input slice when full blocks are available.
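    //
    // As an illustrative walk-through: calling `update()` on a fresh state with
    // 150 bytes compresses the first two 64-byte blocks straight out of the
    // input slice and buffers the remaining 22 bytes; the buffered tail is held
    // back so that `output()` can mark the final block of the chunk with
    // `CHUNK_END`.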
    const fn update(&mut self, mut input: &[u8]) -> &mut Self {
        if self.buf_len > 0 {
            self.fill_buf(&mut input);
            if !input.is_empty() {
                debug_assert!(self.buf_len as usize == BLOCK_LEN);
                let block_flags = self.flags | self.start_flag(); // borrowck
                let block_words = words_from_le_bytes_64(&self.buf);
                portable::compress_in_place(
                    &mut self.cv,
                    &block_words,
                    BLOCK_LEN as u32,
                    self.chunk_counter,
                    block_flags as u32,
                );
                self.buf_len = 0;
                self.buf = [0; BLOCK_LEN];
                self.blocks_compressed += 1;
            }
        }

        // Note the strict inequality: the final block always goes through the
        // buffer, so that `output()` can mark it with `CHUNK_END`.
        while input.len() > BLOCK_LEN {
            debug_assert!(self.buf_len == 0);
            let block_flags = self.flags | self.start_flag(); // borrowck
            let block = input
                .first_chunk::<BLOCK_LEN>()
                .expect("Iteration only starts when there are at least `BLOCK_LEN` bytes; qed");
            let block_words = words_from_le_bytes_64(block);
            portable::compress_in_place(
                &mut self.cv,
                &block_words,
                BLOCK_LEN as u32,
                self.chunk_counter,
                block_flags as u32,
            );
            self.blocks_compressed += 1;
            input = input.split_at(BLOCK_LEN).1;
        }

        self.fill_buf(&mut input);
        debug_assert!(input.is_empty());
        debug_assert!(self.count() <= CHUNK_LEN);
        self
    }

    const fn output(&self) -> ConstOutput {
        let block_flags = self.flags | self.start_flag() | CHUNK_END;
        ConstOutput {
            input_chaining_value: self.cv,
            block: self.buf,
            block_len: self.buf_len,
            counter: self.chunk_counter,
            flags: block_flags,
        }
    }
}

// IMPLEMENTATION NOTE
// ===================
// (This note is inherited from the upstream crate; this `const` copy only uses
// the single-threaded portable backend.)
//
// The recursive function compress_subtree_wide(), implemented below, is the
// basis of high-performance BLAKE3. We use it both for all-at-once hashing and
// for incremental input with Hasher (though we have to be careful with subtree
// boundaries in the incremental case). compress_subtree_wide() applies several
// optimizations at the same time:
// - Multithreading with Rayon.
// - Parallel chunk hashing with SIMD.
// - Parallel parent hashing with SIMD. Note that while SIMD chunk hashing maxes out at
//   MAX_SIMD_DEGREE*CHUNK_LEN, parallel parent hashing continues to benefit from larger inputs,
//   because more levels of the tree can use full-width SIMD vectors for parent hashing.
//   Without parallel parent hashing, we lose about 10% of overall throughput on AVX2 and AVX-512.

// Use SIMD parallelism to hash up to MAX_SIMD_DEGREE chunks at the same time
// on a single thread. Write out the chunk chaining values and return the
// number of chunks hashed. These chunks are never the root and never empty;
// those cases use a different codepath.
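//
// An illustrative note for this `const` copy: const_compress_subtree_wide()
// only ever calls this function with at most one chunk of input, so either the
// input is exactly one full chunk (hashed via portable::hash_many()) or one
// partial chunk (hashed via a ConstChunkState below); both cases return 1.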
const fn const_compress_chunks_parallel(
    input: &[u8],
    key: &CVWords,
    chunk_counter: u64,
    flags: u8,
    out: &mut [u8],
) -> usize {
    debug_assert!(!input.is_empty(), "empty chunks below the root");
    debug_assert!(input.len() <= MAX_SIMD_DEGREE * CHUNK_LEN);

    let mut chunks = input;
    let mut chunks_so_far = 0;
    let mut chunks_array = [MaybeUninit::<&[u8; CHUNK_LEN]>::uninit(); MAX_SIMD_DEGREE];
    while let Some(chunk) = chunks.first_chunk::<CHUNK_LEN>() {
        chunks = chunks.split_at(CHUNK_LEN).1;
        chunks_array[chunks_so_far].write(chunk);
        chunks_so_far += 1;
    }
    portable::hash_many(
        // SAFETY: Exactly `chunks_so_far` elements of `chunks_array` were initialized above
        unsafe {
            slice::from_raw_parts(
                chunks_array.as_ptr().cast::<&[u8; CHUNK_LEN]>(),
                chunks_so_far,
            )
        },
        key,
        chunk_counter,
        IncrementCounter::Yes,
        flags,
        CHUNK_START,
        CHUNK_END,
        out,
    );

    // Hash the remaining partial chunk, if there is one. Note that the empty
    // chunk (meaning the empty message) is a different codepath.
    if !chunks.is_empty() {
        let counter = chunk_counter + chunks_so_far as u64;
        let mut chunk_state = ConstChunkState::new(key, counter, flags);
        chunk_state.update(chunks);
        let out = out
            .split_at_mut(chunks_so_far * OUT_LEN)
            .1
            .split_at_mut(OUT_LEN)
            .0;
        let chaining_value = chunk_state.output().chaining_value();
        out.copy_from_slice(&chaining_value);
        chunks_so_far + 1
    } else {
        chunks_so_far
    }
}

// Use SIMD parallelism to hash up to MAX_SIMD_DEGREE parents at the same time
// on a single thread. Write out the parent chaining values and return the
// number of parents hashed. (If there's an odd input chaining value left over,
// return it as an additional output.) These parents are never the root and
// never empty; those cases use a different codepath.
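//
// As an illustrative sketch: given 5 child chaining values (160 bytes), two
// parent blocks are hashed via portable::hash_many(), the fifth chaining value
// is copied through unmodified, and the function returns 3.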
const fn const_compress_parents_parallel(
    child_chaining_values: &[u8],
    key: &CVWords,
    flags: u8,
    out: &mut [u8],
) -> usize {
    debug_assert!(
        child_chaining_values.len().is_multiple_of(OUT_LEN),
        "wacky hash bytes"
    );
    let num_children = child_chaining_values.len() / OUT_LEN;
    debug_assert!(num_children >= 2, "not enough children");
    debug_assert!(num_children <= 2 * MAX_SIMD_DEGREE_OR_2, "too many");

    let mut parents = child_chaining_values;
    // Use MAX_SIMD_DEGREE_OR_2 rather than MAX_SIMD_DEGREE here, because of
    // the requirements of compress_subtree_wide().
    let mut parents_so_far = 0;
    let mut parents_array = [MaybeUninit::<&BlockBytes>::uninit(); MAX_SIMD_DEGREE_OR_2];
    while let Some(parent) = parents.first_chunk::<BLOCK_LEN>() {
        parents = parents.split_at(BLOCK_LEN).1;
        parents_array[parents_so_far].write(parent);
        parents_so_far += 1;
    }
    portable::hash_many(
        // SAFETY: Exactly `parents_so_far` elements of `parents_array` were initialized above
        unsafe {
            slice::from_raw_parts(parents_array.as_ptr().cast::<&BlockBytes>(), parents_so_far)
        },
        key,
        0, // Parents always use counter 0.
        IncrementCounter::No,
        flags | PARENT,
        0, // Parents have no start flags.
        0, // Parents have no end flags.
        out,
    );

    // If there's an odd child left over, it becomes an output.
    if !parents.is_empty() {
        let out = out
            .split_at_mut(parents_so_far * OUT_LEN)
            .1
            .split_at_mut(OUT_LEN)
            .0;
        out.copy_from_slice(parents);
        parents_so_far + 1
    } else {
        parents_so_far
    }
}

// The wide helper function writes out an array of chaining values and returns
// the length of that array. The number of chaining values returned is the
// dynamically detected SIMD degree, at most MAX_SIMD_DEGREE, or fewer if the
// input is shorter than that many chunks. The reason for maintaining a wide
// array of chaining values going back up the tree is to allow the
// implementation to hash as many parents in parallel as possible.
//
// As a special case when the SIMD degree is 1, this function will still return
// at least 2 outputs. This guarantees that this function doesn't perform the
// root compression. (If it did, it would use the wrong flags, and also we
// wouldn't be able to implement extendable output.) Note that this function is
// not used when the whole input is only 1 chunk long; that's a different
// codepath.
//
// Why not just have the caller split the input on the first update(), instead
// of implementing this special rule? Because we don't want to limit SIMD or
// multithreading parallelism for that update().
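//
// As an illustrative walk-through of this portable `const` copy: a 3-chunk
// input splits into a 2-chunk left subtree and a 1-chunk right subtree. The
// left recursion returns 2 chaining values (the simd_degree=1 special case),
// the right returns 1, and one layer of parent compression condenses those 3
// into 2 outputs: one parent chaining value plus the odd child passed through.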
const fn const_compress_subtree_wide(
    input: &[u8],
    key: &CVWords,
    chunk_counter: u64,
    flags: u8,
    out: &mut [u8],
) -> usize {
    if input.len() <= CHUNK_LEN {
        return const_compress_chunks_parallel(input, key, chunk_counter, flags, out);
    }

    let (left, right) = input.split_at(hazmat::left_subtree_len(input.len() as u64) as usize);
    let right_chunk_counter = chunk_counter + (left.len() / CHUNK_LEN) as u64;

    // Make space for the child outputs. Here we use MAX_SIMD_DEGREE_OR_2 to
    // account for the special case of returning 2 outputs when the SIMD degree
    // is 1.
    let mut cv_array = [0; 2 * MAX_SIMD_DEGREE_OR_2 * OUT_LEN];
    let degree = if left.len() == CHUNK_LEN { 1 } else { 2 };
    let (left_out, right_out) = cv_array.split_at_mut(degree * OUT_LEN);

    // Recurse!
    let left_n = const_compress_subtree_wide(left, key, chunk_counter, flags, left_out);
    let right_n = const_compress_subtree_wide(right, key, right_chunk_counter, flags, right_out);

    // The special case again. If simd_degree=1, then we'll have left_n=1 and
    // right_n=1. Rather than compressing them into a single output, return
    // them directly, to make sure we always have at least two outputs.
    debug_assert!(left_n == degree);
    debug_assert!(right_n >= 1 && right_n <= left_n);
    if left_n == 1 {
        out.split_at_mut(2 * OUT_LEN)
            .0
            .copy_from_slice(cv_array.split_at(2 * OUT_LEN).0);
        return 2;
    }

    // Otherwise, do one layer of parent node compression.
    let num_children = left_n + right_n;
    const_compress_parents_parallel(cv_array.split_at(num_children * OUT_LEN).0, key, flags, out)
}

// Hash a subtree with compress_subtree_wide(), and then condense the resulting
// list of chaining values down to a single parent node. Don't compress that
// last parent node, however. Instead, return its message bytes (the
// concatenated chaining values of its children). This is necessary when the
// first call to update() supplies a complete subtree, because the topmost
// parent node of that subtree could end up being the root. It's also necessary
// for extended output in the general case.
//
// As with compress_subtree_wide(), this function is not used on inputs of 1
// chunk or less. That's a different codepath.
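//
// A note on this `const` copy: const_compress_subtree_wide() caps its degree
// at 2, so `num_cvs` below is always exactly 2 and the condensing loop is
// inherited from the upstream SIMD code, where e.g. 4 chaining values would be
// reduced to 2 by one round of parent compression.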
const fn const_compress_subtree_to_parent_node(
    input: &[u8],
    key: &CVWords,
    chunk_counter: u64,
    flags: u8,
) -> BlockBytes {
    debug_assert!(input.len() > CHUNK_LEN);
    let mut cv_array = [0; MAX_SIMD_DEGREE_OR_2 * OUT_LEN];
    let mut num_cvs = const_compress_subtree_wide(input, key, chunk_counter, flags, &mut cv_array);
    debug_assert!(num_cvs >= 2);

    // If MAX_SIMD_DEGREE is greater than 2 and there's enough input,
    // compress_subtree_wide() returns more than 2 chaining values. Condense
    // them into 2 by forming parent nodes repeatedly.
    let mut out_array = [0; MAX_SIMD_DEGREE_OR_2 * OUT_LEN / 2];
    while num_cvs > 2 {
        let cv_slice = cv_array.split_at(num_cvs * OUT_LEN).0;
        num_cvs = const_compress_parents_parallel(cv_slice, key, flags, &mut out_array);
        cv_array
            .split_at_mut(num_cvs * OUT_LEN)
            .0
            .copy_from_slice(out_array.split_at(num_cvs * OUT_LEN).0);
    }
    *cv_array
        .first_chunk::<BLOCK_LEN>()
        .expect("`cv_array` is at least `BLOCK_LEN` bytes; qed")
}

// Hash a complete input all at once. Unlike compress_subtree_wide() and
// compress_subtree_to_parent_node(), this function handles the 1 chunk case.
const fn const_hash_all_at_once(input: &[u8], key: &CVWords, flags: u8) -> ConstOutput {
    // If the whole subtree is one chunk, hash it directly with a ChunkState.
    if input.len() <= CHUNK_LEN {
        return ConstChunkState::new(key, 0, flags).update(input).output();
    }

    // Otherwise construct a `ConstOutput` object from the parent node returned by
    // compress_subtree_to_parent_node().
    ConstOutput {
        input_chaining_value: *key,
        block: const_compress_subtree_to_parent_node(input, key, 0, flags),
        block_len: BLOCK_LEN as u8,
        counter: 0,
        flags: flags | PARENT,
    }
}

/// Hashing function like [`blake3::hash()`], but `const fn`
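///
/// # Example
///
/// A minimal sketch, assuming this function is re-exported at the crate root
/// (marked `ignore` so it is not compiled as a doctest):
///
/// ```ignore
/// // Computed entirely at compile time; OUT_LEN == 32.
/// const DIGEST: [u8; 32] = ab_blake3::const_hash(b"example input");
/// ```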
pub const fn const_hash(input: &[u8]) -> [u8; OUT_LEN] {
    const_hash_all_at_once(input, IV, 0).root_hash()
}

/// Keyed hashing function like [`blake3::keyed_hash()`], but `const fn`
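///
/// # Example
///
/// A minimal sketch, assuming a crate-root re-export (marked `ignore` so it is
/// not compiled as a doctest):
///
/// ```ignore
/// const KEY: [u8; 32] = [0x42; 32]; // KEY_LEN == 32
/// const MAC: [u8; 32] = ab_blake3::const_keyed_hash(&KEY, b"message");
/// ```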
pub const fn const_keyed_hash(key: &[u8; KEY_LEN], input: &[u8]) -> [u8; OUT_LEN] {
    let key_words = words_from_le_bytes_32(key);
    const_hash_all_at_once(input, &key_words, KEYED_HASH).root_hash()
}

/// Key derivation function like [`blake3::derive_key()`], but `const fn`
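///
/// # Example
///
/// A minimal sketch, assuming a crate-root re-export (marked `ignore` so it is
/// not compiled as a doctest). Per the upstream convention, the context string
/// should be hardcoded, globally unique, and application-specific:
///
/// ```ignore
/// const CONTEXT: &str = "example.com 2025 session keys v1";
/// const KEY: [u8; 32] = ab_blake3::const_derive_key(CONTEXT, b"input key material");
/// ```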
pub const fn const_derive_key(context: &str, key_material: &[u8]) -> [u8; OUT_LEN] {
    let context_key =
        const_hash_all_at_once(context.as_bytes(), IV, DERIVE_KEY_CONTEXT).root_hash();
    let context_key_words = words_from_le_bytes_32(&context_key);
    const_hash_all_at_once(key_material, &context_key_words, DERIVE_KEY_MATERIAL).root_hash()
}