ab_chacha8/
lib.rs

//! Small, GPU-friendly software implementation of the ChaCha8 stream cipher
2
3#![no_std]
4
5#[cfg(test)]
6mod tests;
7
/// One ChaCha8 block: sixteen 32-bit words (64 bytes), used both for the cipher input state and
/// for the keystream output of a single block.
pub type ChaCha8Block = [u32; 16];
10
11/// Convert block to bytes
12#[inline(always)]
13#[cfg_attr(feature = "no-panic", no_panic::no_panic)]
14pub fn block_to_bytes(block: &ChaCha8Block) -> [u8; 64] {
15    // SAFETY: Same size and no alignment requirements
16    unsafe { block.as_ptr().cast::<[u8; 64]>().read() }
17}
18
19/// Create an instance from internal representation
20#[inline(always)]
21#[cfg_attr(feature = "no-panic", no_panic::no_panic)]
22pub fn bytes_to_block(bytes: &[u8; 64]) -> ChaCha8Block {
23    // SAFETY: Same size, all bit patterns are valid
24    unsafe { bytes.as_ptr().cast::<ChaCha8Block>().read_unaligned() }
25}
26
27/// State of ChaCha8 cipher
28#[derive(Debug, Copy, Clone)]
29pub struct ChaCha8State {
30    data: ChaCha8Block,
31}
32
33impl ChaCha8State {
34    const ROUNDS: usize = 8;
35
36    /// Initialize ChaCha8 state
37    #[inline]
38    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
39    pub fn init(key: &[u8; 32], nonce: &[u8; 12]) -> Self {
40        let mut data = [0u32; 16];
41        data[0] = 0x61707865;
42        data[1] = 0x3320646e;
43        data[2] = 0x79622d32;
44        data[3] = 0x6b206574;
45
46        for (i, &chunk) in key.as_chunks::<4>().0.iter().enumerate() {
47            data[4 + i] = u32::from_le_bytes(chunk);
48        }
49
50        // `data[12]` and `data[13]` is counter specific to each block, thus not set here
51
52        for (i, &chunk) in nonce.as_chunks::<4>().0.iter().enumerate() {
53            data[13 + i] = u32::from_le_bytes(chunk);
54        }
55
56        Self { data }
57    }
58
59    /// Convert to internal representation
60    #[inline(always)]
61    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
62    pub fn to_repr(self) -> ChaCha8Block {
63        self.data
64    }
65
66    /// Create an instance from internal representation
67    #[inline(always)]
68    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
69    pub fn from_repr(data: ChaCha8Block) -> Self {
70        Self { data }
71    }
72
73    /// Compute block for specified counter.
74    ///
75    /// Counter is only 32-bit because that is all that is needed for target use case.
76    #[inline(always)]
77    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
78    pub fn compute_block(mut self, counter: u32) -> ChaCha8Block {
79        self.data[12] = counter;
80        // Not setting `data[13]` due to counter being limited to `u32`
81
82        let initial = self.data;
83
84        for _ in 0..Self::ROUNDS / 2 {
85            self.quarter_round(0, 4, 8, 12);
86            self.quarter_round(1, 5, 9, 13);
87            self.quarter_round(2, 6, 10, 14);
88            self.quarter_round(3, 7, 11, 15);
89
90            self.quarter_round(0, 5, 10, 15);
91            self.quarter_round(1, 6, 11, 12);
92            self.quarter_round(2, 7, 8, 13);
93            self.quarter_round(3, 4, 9, 14);
94        }
95
96        // TODO: More idiomatic version currently doesn't compile:
97        //  https://github.com/Rust-GPU/rust-gpu/issues/241#issuecomment-3005693043
98        #[allow(clippy::needless_range_loop)]
99        // for (d, initial) in self.data.iter_mut().zip(initial) {
100        //     *d = d.wrapping_add(initial);
101        // }
102        for i in 0..16 {
103            self.data[i] = self.data[i].wrapping_add(initial[i]);
104        }
105
106        self.data
107    }
108
109    #[inline(always)]
110    #[cfg_attr(feature = "no-panic", no_panic::no_panic)]
111    fn quarter_round(&mut self, a: usize, b: usize, c: usize, d: usize) {
112        self.data[a] = self.data[a].wrapping_add(self.data[b]);
113        self.data[d] ^= self.data[a];
114        self.data[d] = self.data[d].rotate_left(16);
115
116        self.data[c] = self.data[c].wrapping_add(self.data[d]);
117        self.data[b] ^= self.data[c];
118        self.data[b] = self.data[b].rotate_left(12);
119
120        self.data[a] = self.data[a].wrapping_add(self.data[b]);
121        self.data[d] ^= self.data[a];
122        self.data[d] = self.data[d].rotate_left(8);
123
124        self.data[c] = self.data[c].wrapping_add(self.data[d]);
125        self.data[b] ^= self.data[c];
126        self.data[b] = self.data[b].rotate_left(7);
127    }
128}