// ab_riscv_interpreter/rv64/zk/zkn/zknd/rv64_zknd_helpers.rs

1//! Opaque helpers for RV64 Zknd extension
2
3use ab_riscv_primitives::prelude::*;
4
5/// Key schedule operations shared across all backends.
6///
7/// Neither `aes64ks1i` nor `aes64ks2` has a hardware mapping on non-riscv64.
8#[cfg(not(all(not(miri), target_arch = "riscv64", target_feature = "zknd")))]
9mod ks {
10    use crate::rv32::zk::zkn::zknd::rv32_zknd_helpers::SBOX;
11    use ab_riscv_primitives::prelude::*;
12
13    /// AES key schedule step 1.
14    ///
15    /// Pseudocode (RISC-V Crypto spec Sail source):
16    /// ```text
17    ///   temp = rs1[63:32]
18    ///   if rnum != 0xA: temp = RotWord(temp)
19    ///   temp = SubWord(temp)
20    ///   if rnum != 0xA: temp ^= RCON[rnum]
21    ///   rd = temp | (temp << 32)
22    /// ```
23    #[inline(always)]
24    pub(super) fn aes64ks1i(rs1: u64, rnum: Rv64ZkndKsRnum) -> u64 {
25        let w = (rs1 >> 32) as u32;
26
27        let rotated = if rnum != Rv64ZkndKsRnum::Final {
28            w.rotate_right(8)
29        } else {
30            w
31        };
32
33        let b0 = u32::from(SBOX[(rotated & 0xff) as usize]);
34        let b1 = u32::from(SBOX[((rotated >> 8) & 0xff) as usize]);
35        let b2 = u32::from(SBOX[((rotated >> 16) & 0xff) as usize]);
36        let b3 = u32::from(SBOX[((rotated >> 24) & 0xff) as usize]);
37        let subbed = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24);
38
39        let result = if let Some(round_constant) = rnum.constant() {
40            subbed ^ u32::from(round_constant)
41        } else {
42            subbed
43        };
44
45        u64::from(result) | (u64::from(result) << 32)
46    }
47
48    /// AES key schedule step 2.
49    ///
50    /// Pseudocode (RISC-V Crypto spec):
51    /// ```text
52    ///   w0 = rs1[63:32] ^ rs2[31:0]
53    ///   w1 = rs1[63:32] ^ rs2[31:0] ^ rs2[63:32]
54    ///   rd = w0 | (w1 << 32)
55    /// ```
56    #[inline(always)]
57    pub(super) fn aes64ks2(rs1: u64, rs2: u64) -> u64 {
58        let w0 = (rs1 >> 32) as u32 ^ rs2 as u32;
59        let w1 = w0 ^ (rs2 >> 32) as u32;
60        u64::from(w0) | (u64::from(w1) << 32)
61    }
62}
63
64cfg_select! {
65    all(
66        not(miri),
67        target_arch = "riscv64",
68        target_feature = "zknd"
69    ) => {
70        // Nothing, calling native intrinsics
71    }
72    all(target_arch = "x86_64", target_feature = "aes", target_feature = "sse4.1") => {
73        /// x86-64 AES-NI implementation
74        mod x86_64 {
75            use core::arch::x86_64::{
76                _mm_aesdec_si128, _mm_aesdeclast_si128, _mm_aesimc_si128, _mm_extract_epi64,
77                _mm_set_epi64x, _mm_setzero_si128,
78            };
79
80            /// `_mm_aesdeclast_si128(state, zero)` computes InvShiftRows + InvSubBytes, then XORs
81            /// with the round key. Zero key -> no-op XOR, matching `aes64ds`.
82            #[inline]
83            #[target_feature(enable = "aes,sse4.1")]
84            pub(super) fn aes64ds(rs1: u64, rs2: u64) -> u64 {
85                let state = _mm_set_epi64x(rs2.cast_signed(), rs1.cast_signed());
86                let zero = _mm_setzero_si128();
87                let result = _mm_aesdeclast_si128(state, zero);
88                _mm_extract_epi64::<0>(result).cast_unsigned()
89            }
90
91            /// `_mm_aesdec_si128(state, zero)` computes InvShiftRows + InvSubBytes + InvMixColumns,
92            /// then XORs with the round key. Zero key -> no-op XOR.
93            #[inline]
94            #[target_feature(enable = "aes,sse4.1")]
95            pub(super) fn aes64dsm(rs1: u64, rs2: u64) -> u64 {
96                let state = _mm_set_epi64x(rs2.cast_signed(), rs1.cast_signed());
97                let zero = _mm_setzero_si128();
98                let result = _mm_aesdec_si128(state, zero);
99                _mm_extract_epi64::<0>(result).cast_unsigned()
100            }
101
102            /// `_mm_aesimc_si128` applies InvMixColumns to all four 32-bit columns.
103            /// `rs1` is replicated into both halves; we extract the low 64 bits.
104            #[inline]
105            #[target_feature(enable = "aes,sse4.1")]
106            pub(super) fn aes64im(rs1: u64) -> u64 {
107                let state = _mm_set_epi64x(rs1.cast_signed(), rs1.cast_signed());
108                let result = _mm_aesimc_si128(state);
109                _mm_extract_epi64::<0>(result).cast_unsigned()
110            }
111        }
112    }
113    all(target_arch = "aarch64", target_feature = "aes") => {
114        /// AArch64 AES implementation
115        mod aarch64 {
116            use core::arch::aarch64::{
117                vaesdq_u8, vaesimcq_u8, vcombine_u64, vcreate_u64, vdupq_n_u8, vgetq_lane_u64,
118                vreinterpretq_u8_u64, vreinterpretq_u64_u8,
119            };
120
121            /// `vaesdq_u8(state, zero)` computes XOR(zero) then InvShiftRows + InvSubBytes. ARM's
122            /// AESD operates in the same byte order as the RISC-V half-state model when
123            /// `(rs1, rs2)` is loaded little-endian; no swap needed.
124            #[inline]
125            #[target_feature(enable = "aes")]
126            pub(super) fn aes64ds(rs1: u64, rs2: u64) -> u64 {
127                let state = vreinterpretq_u8_u64(vcombine_u64(vcreate_u64(rs1), vcreate_u64(rs2)));
128                let zero = vdupq_n_u8(0);
129                let result = vaesdq_u8(state, zero);
130                vgetq_lane_u64::<0>(vreinterpretq_u64_u8(result))
131            }
132
133            /// `vaesimcq_u8(vaesdq_u8(state, zero))` maps exactly to `aes64dsm`
134            #[inline]
135            #[target_feature(enable = "aes")]
136            pub(super) fn aes64dsm(rs1: u64, rs2: u64) -> u64 {
137                let state = vreinterpretq_u8_u64(vcombine_u64(vcreate_u64(rs1), vcreate_u64(rs2)));
138                let zero = vdupq_n_u8(0);
139                let after_sub_shift = vaesdq_u8(state, zero);
140                let result = vaesimcq_u8(after_sub_shift);
141                vgetq_lane_u64::<0>(vreinterpretq_u64_u8(result))
142            }
143
144            #[inline]
145            #[target_feature(enable = "aes")]
146            pub(super) fn aes64im(rs1: u64) -> u64 {
147                let state = vreinterpretq_u8_u64(vcombine_u64(vcreate_u64(rs1), vcreate_u64(rs1)));
148                let result = vaesimcq_u8(state);
149                vgetq_lane_u64::<0>(vreinterpretq_u64_u8(result))
150            }
151        }
152    }
153    _ => {
154        /// Software fallback for aes64ds, aes64dsm, aes64im
155        mod soft {
156            use crate::rv32::zk::zkn::zknd::rv32_zknd_helpers::{INV_SBOX, gmul};
157
158            #[inline(always)]
159            fn inv_mix_col(col: u32) -> u32 {
160                let s0 = col as u8;
161                let s1 = (col >> 8) as u8;
162                let s2 = (col >> 16) as u8;
163                let s3 = (col >> 24) as u8;
164                let r0 = gmul(s0, 0x0e) ^ gmul(s1, 0x0b) ^ gmul(s2, 0x0d) ^ gmul(s3, 0x09);
165                let r1 = gmul(s0, 0x09) ^ gmul(s1, 0x0e) ^ gmul(s2, 0x0b) ^ gmul(s3, 0x0d);
166                let r2 = gmul(s0, 0x0d) ^ gmul(s1, 0x09) ^ gmul(s2, 0x0e) ^ gmul(s3, 0x0b);
167                let r3 = gmul(s0, 0x0b) ^ gmul(s1, 0x0d) ^ gmul(s2, 0x09) ^ gmul(s3, 0x0e);
168                (r0 as u32) | ((r1 as u32) << 8) | ((r2 as u32) << 16) | ((r3 as u32) << 24)
169            }
170
171            /// Apply InvShiftRows + InvSubBytes to the full 128-bit state `(rs1, rs2)` and return
172            /// the low 64-bit half of the result.
173            ///
174            /// State layout: column-major, little-endian 64-bit halves.
175            /// `byte[col*4 + row]` is at bit `(row*8)` of `rs1` for `col < 2`, or bit `(row*8)` of
176            /// `rs2` for `col >= 2`.
177            ///
178            /// InvShiftRows shifts row `r` right by `r` columns (cyclically over 4).
179            /// Output low half contains post-transform columns 0 and 1.
180            #[inline(always)]
181            pub(super) fn aes64ds(rs1: u64, rs2: u64) -> u64 {
182                let state_byte = |col: usize, row: usize| -> u8 {
183                    let word = if col < 2 { rs1 } else { rs2 };
184                    (word >> ((col % 2) * 32 + row * 8)) as u8
185                };
186
187                let mut out = 0;
188                for c in 0..2usize {
189                    for r in 0..4usize {
190                        let src_col = (c + 4 - r) & 3;
191                        let b = INV_SBOX[state_byte(src_col, r) as usize];
192                        out |= (b as u64) << (c * 32 + r * 8);
193                    }
194                }
195                out
196            }
197
198            #[inline(always)]
199            pub(super) fn aes64dsm(rs1: u64, rs2: u64) -> u64 {
200                let lo = aes64ds(rs1, rs2);
201                let col0 = inv_mix_col(lo as u32);
202                let col1 = inv_mix_col((lo >> 32) as u32);
203                (col0 as u64) | ((col1 as u64) << 32)
204            }
205
206            #[inline(always)]
207            pub(super) fn aes64im(rs1: u64) -> u64 {
208                let col0 = inv_mix_col(rs1 as u32);
209                let col1 = inv_mix_col((rs1 >> 32) as u32);
210                (col0 as u64) | ((col1 as u64) << 32)
211            }
212        }
213    }
214}
215
/// `aes64ds`: AES final-round decryption step.
///
/// Applies InvShiftRows + InvSubBytes to the 128-bit state held in `(rs1, rs2)`
/// (`rs1` = low half, `rs2` = high half) and returns the low 64 bits of the
/// result. Dispatches at compile time to the native RV64 Zknd intrinsic, an
/// x86-64 AES-NI shim, an AArch64 AES shim, or the portable software fallback.
#[inline(always)]
#[doc(hidden)]
pub fn aes64ds(rs1: u64, rs2: u64) -> u64 {
    // TODO: Miri is excluded because corresponding intrinsic is not implemented there
    cfg_select! {
        all(
            not(miri),
            target_arch = "riscv64",
            target_feature = "zknd"
        ) => {
            // SAFETY: Compile-time checked for supported feature
            unsafe {
                core::arch::riscv64::aes64ds(rs1, rs2)
            }
        }
        all(target_arch = "x86_64", target_feature = "aes", target_feature = "sse4.1") => {
            // SAFETY: Compile-time checked for supported feature
            unsafe {
                x86_64::aes64ds(rs1, rs2)
            }
        }
        all(target_arch = "aarch64", target_feature = "aes") => {
            // SAFETY: Compile-time checked for supported feature
            unsafe {
                aarch64::aes64ds(rs1, rs2)
            }
        }
        _ => { soft::aes64ds(rs1, rs2) }
    }
}
246
/// `aes64dsm`: AES middle-round decryption step.
///
/// Applies InvShiftRows + InvSubBytes + InvMixColumns to the 128-bit state held
/// in `(rs1, rs2)` (`rs1` = low half, `rs2` = high half) and returns the low
/// 64 bits of the result. Backend dispatch mirrors [`aes64ds`].
#[inline(always)]
#[doc(hidden)]
pub fn aes64dsm(rs1: u64, rs2: u64) -> u64 {
    // TODO: Miri is excluded because corresponding intrinsic is not implemented there
    cfg_select! {
        all(
            not(miri),
            target_arch = "riscv64",
            target_feature = "zknd"
        ) => {
            // SAFETY: Compile-time checked for supported feature
            unsafe {
                core::arch::riscv64::aes64dsm(rs1, rs2)
            }
        }
        all(target_arch = "x86_64", target_feature = "aes", target_feature = "sse4.1") => {
            // SAFETY: Compile-time checked for supported feature
            unsafe {
                x86_64::aes64dsm(rs1, rs2)
            }
        }
        all(target_arch = "aarch64", target_feature = "aes") => {
            // SAFETY: Compile-time checked for supported feature
            unsafe {
                aarch64::aes64dsm(rs1, rs2)
            }
        }
        _ => { soft::aes64dsm(rs1, rs2) }
    }
}
277
/// `aes64im`: AES InvMixColumns helper.
///
/// Applies InvMixColumns to the two 32-bit columns packed in `rs1` and returns
/// the transformed 64-bit value. Backend dispatch mirrors [`aes64ds`].
#[inline(always)]
#[doc(hidden)]
pub fn aes64im(rs1: u64) -> u64 {
    // TODO: Miri is excluded because corresponding intrinsic is not implemented there
    cfg_select! {
        all(
            not(miri),
            target_arch = "riscv64",
            target_feature = "zknd"
        ) => {
            // SAFETY: Compile-time checked for supported feature
            unsafe {
                core::arch::riscv64::aes64im(rs1)
            }
        }
        all(target_arch = "x86_64", target_feature = "aes", target_feature = "sse4.1") => {
            // SAFETY: Compile-time checked for supported feature
            unsafe {
                x86_64::aes64im(rs1)
            }
        }
        all(target_arch = "aarch64", target_feature = "aes") => {
            // SAFETY: Compile-time checked for supported feature
            unsafe {
                aarch64::aes64im(rs1)
            }
        }
        _ => { soft::aes64im(rs1) }
    }
}
308
/// `aes64ks1i`: AES key schedule instruction 1.
///
/// Computes the RotWord/SubWord/RCON step of the AES key schedule on the high
/// word of `rs1` for round number `rnum`. On riscv64 the intrinsic receives the
/// round number as the enum's `u8` representation (`rnum as u8`); everywhere
/// else the shared [`ks`] software implementation is used.
#[inline(always)]
#[doc(hidden)]
pub fn aes64ks1i(rs1: u64, rnum: Rv64ZkndKsRnum) -> u64 {
    // TODO: Miri is excluded because corresponding intrinsic is not implemented there
    cfg_select! {
        all(
            not(miri),
            target_arch = "riscv64",
            target_feature = "zknd"
        ) => {
            // SAFETY: Compile-time checked for supported feature
            unsafe {
                core::arch::riscv64::aes64ks1i(rs1, rnum as u8)
            }
        }
        _ => { ks::aes64ks1i(rs1, rnum) }
    }
}
327
/// `aes64ks2`: AES key schedule instruction 2.
///
/// XOR-combines the high word of `rs1` with both words of `rs2` to produce the
/// next two key-schedule words. Uses the native riscv64 intrinsic when
/// available, otherwise the shared [`ks`] software implementation.
#[inline(always)]
#[doc(hidden)]
pub fn aes64ks2(rs1: u64, rs2: u64) -> u64 {
    // TODO: Miri is excluded because corresponding intrinsic is not implemented there
    cfg_select! {
        all(
            not(miri),
            target_arch = "riscv64",
            target_feature = "zknd"
        ) => {
            // SAFETY: Compile-time checked for supported feature
            unsafe {
                core::arch::riscv64::aes64ks2(rs1, rs2)
            }
        }
        _ => { ks::aes64ks2(rs1, rs2) }
    }
}