ab_chacha8/
lib.rs

1//! Small GPU-friendly software implementation of ChaCha8
2
3#![no_std]
4#![feature(array_chunks)]
5
6#[cfg(test)]
7mod tests;
8
/// One ChaCha8 block: sixteen 32-bit words, which the rounds treat as a 4x4 matrix.
pub type ChaCha8Block = [u32; 4 * 4];
11
12/// Convert block to bytes
13#[inline(always)]
14#[cfg_attr(feature = "no-panic", no_panic::no_panic)]
15pub fn block_to_bytes(block: &ChaCha8Block) -> [u8; 64] {
16    // SAFETY: Same size and no alignment requirements
17    unsafe { block.as_ptr().cast::<[u8; 64]>().read() }
18}
19
20/// Create an instance from internal representation
21#[inline(always)]
22#[cfg_attr(feature = "no-panic", no_panic::no_panic)]
23pub fn bytes_to_block(bytes: &[u8; 64]) -> ChaCha8Block {
24    // SAFETY: Same size, all bit patterns are valid
25    unsafe { bytes.as_ptr().cast::<ChaCha8Block>().read_unaligned() }
26}
27
28/// State of ChaCha8 cipher
29#[derive(Debug, Copy, Clone)]
30pub struct ChaCha8State {
31    data: ChaCha8Block,
32}
33
34impl ChaCha8State {
35    const ROUNDS: usize = 8;
36
37    /// Initialize ChaCha8 state
38    #[inline]
39    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
40    pub fn init(key: &[u8; 32], nonce: &[u8; 12]) -> Self {
41        let mut data = [0u32; 16];
42        data[0] = 0x61707865;
43        data[1] = 0x3320646e;
44        data[2] = 0x79622d32;
45        data[3] = 0x6b206574;
46
47        for (i, &chunk) in key.array_chunks::<4>().enumerate() {
48            data[4 + i] = u32::from_le_bytes(chunk);
49        }
50
51        // `data[12]` and `data[13]` is counter specific to each block, thus not set here
52
53        for (i, &chunk) in nonce.array_chunks::<4>().enumerate() {
54            data[13 + i] = u32::from_le_bytes(chunk);
55        }
56
57        Self { data }
58    }
59
60    /// Convert to internal representation
61    #[inline(always)]
62    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
63    pub fn to_repr(self) -> ChaCha8Block {
64        self.data
65    }
66
67    /// Create an instance from internal representation
68    #[inline(always)]
69    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
70    pub fn from_repr(data: ChaCha8Block) -> Self {
71        Self { data }
72    }
73
74    /// Compute block for specified counter.
75    ///
76    /// Counter is only 32-bit because that is all that is needed for target use case.
77    #[inline(always)]
78    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
79    pub fn compute_block(mut self, counter: u32) -> ChaCha8Block {
80        self.data[12] = counter;
81        // Not setting `data[13]` due to counter being limited to `u32`
82
83        let initial = self.data;
84
85        for _ in 0..Self::ROUNDS / 2 {
86            self.quarter_round(0, 4, 8, 12);
87            self.quarter_round(1, 5, 9, 13);
88            self.quarter_round(2, 6, 10, 14);
89            self.quarter_round(3, 7, 11, 15);
90
91            self.quarter_round(0, 5, 10, 15);
92            self.quarter_round(1, 6, 11, 12);
93            self.quarter_round(2, 7, 8, 13);
94            self.quarter_round(3, 4, 9, 14);
95        }
96
97        // TODO: More idiomatic version currently doesn't compile:
98        //  https://github.com/Rust-GPU/rust-gpu/issues/241#issuecomment-3005693043
99        #[allow(clippy::needless_range_loop)]
100        // for (d, initial) in self.data.iter_mut().zip(initial) {
101        //     *d = d.wrapping_add(initial);
102        // }
103        for i in 0..16 {
104            self.data[i] = self.data[i].wrapping_add(initial[i]);
105        }
106
107        self.data
108    }
109
110    #[inline(always)]
111    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
112    fn quarter_round(&mut self, a: usize, b: usize, c: usize, d: usize) {
113        self.data[a] = self.data[a].wrapping_add(self.data[b]);
114        self.data[d] ^= self.data[a];
115        self.data[d] = self.data[d].rotate_left(16);
116
117        self.data[c] = self.data[c].wrapping_add(self.data[d]);
118        self.data[b] ^= self.data[c];
119        self.data[b] = self.data[b].rotate_left(12);
120
121        self.data[a] = self.data[a].wrapping_add(self.data[b]);
122        self.data[d] ^= self.data[a];
123        self.data[d] = self.data[d].rotate_left(8);
124
125        self.data[c] = self.data[c].wrapping_add(self.data[d]);
126        self.data[b] ^= self.data[c];
127        self.data[b] = self.data[b].rotate_left(7);
128    }
129}