// ab_riscv_interpreter/rv64/zk/zkn/zkne/rv64_zkne_helpers.rs

//! Opaque helpers for RV64 Zkne extension
//!
//! `cfg_select!` picks exactly one implementation of the `aes64es` /
//! `aes64esm` primitives at compile time: native RISC-V intrinsics (first
//! arm — nothing to define here), an x86-64 AES-NI emulation, an AArch64
//! AES emulation, or the portable software fallback in the `_` arm.

cfg_select! {
    all(
        not(miri),
        target_arch = "riscv64",
        target_feature = "zkne"
    ) => {
        // Nothing, calling native intrinsics
        // (the public wrappers below call `core::arch::riscv64` directly).
    }
    all(target_arch = "x86_64", target_feature = "aes", target_feature = "sse4.1") => {
        /// x86-64 AES-NI implementation
        ///
        /// AES-NI round instructions end by XOR-ing the round key into the
        /// state; supplying an all-zero key makes that XOR a no-op, leaving
        /// exactly the ShiftRows/SubBytes(/MixColumns) core each RISC-V
        /// instruction needs.
        mod x86_64 {
            use core::arch::x86_64::{
                _mm_aesenclast_si128, _mm_aesenc_si128, _mm_extract_epi64,
                _mm_set_epi64x, _mm_setzero_si128,
            };

            /// `_mm_aesenclast_si128(state, zero)` computes ShiftRows + SubBytes then XORs with
            /// the round key. Zero key -> no-op XOR, matching `aes64es`.
            ///
            /// Note: `_mm_set_epi64x(hi, lo)` takes the HIGH half first, so
            /// `rs2` is the upper 64 bits of the state and `rs1` the lower;
            /// lane 0 of the result is the low 64 bits, matching the RV64
            /// register pairing.
            #[inline]
            #[target_feature(enable = "aes,sse4.1")]
            pub(super) fn aes64es(rs1: u64, rs2: u64) -> u64 {
                let state = _mm_set_epi64x(rs2.cast_signed(), rs1.cast_signed());
                let zero = _mm_setzero_si128();
                let result = _mm_aesenclast_si128(state, zero);
                _mm_extract_epi64::<0>(result).cast_unsigned()
            }

            /// `_mm_aesenc_si128(state, zero)` computes ShiftRows + SubBytes + MixColumns then
            /// XORs with the round key. Zero key -> no-op XOR, matching `aes64esm`.
            #[inline]
            #[target_feature(enable = "aes,sse4.1")]
            pub(super) fn aes64esm(rs1: u64, rs2: u64) -> u64 {
                let state = _mm_set_epi64x(rs2.cast_signed(), rs1.cast_signed());
                let zero = _mm_setzero_si128();
                let result = _mm_aesenc_si128(state, zero);
                _mm_extract_epi64::<0>(result).cast_unsigned()
            }
        }
    }
    all(target_arch = "aarch64", target_feature = "aes") => {
        /// AArch64 AES implementation
        ///
        /// AESE XORs the round key first, then applies SubBytes + ShiftRows (note: ARM ShiftRows
        /// direction matches the forward cipher). With a zero round key the XOR is a no-op,
        /// leaving pure SubBytes + ShiftRows - identical to what `aes64es` requires.
        /// (SubBytes is byte-wise and ShiftRows a byte permutation, so their
        /// order relative to x86's ShiftRows-then-SubBytes does not matter.)
        mod aarch64 {
            use core::arch::aarch64::{
                vaeseq_u8, vaesmcq_u8, vcombine_u64, vcreate_u64, vdupq_n_u8, vgetq_lane_u64,
                vreinterpretq_u8_u64, vreinterpretq_u64_u8,
            };

            /// `vcombine_u64(lo, hi)` places `rs1` in the low half and `rs2`
            /// in the high half of the 128-bit state; lane 0 of the result is
            /// the low 64 bits, matching the RV64 register pairing.
            #[inline]
            #[target_feature(enable = "aes")]
            pub(super) fn aes64es(rs1: u64, rs2: u64) -> u64 {
                let state = vreinterpretq_u8_u64(vcombine_u64(vcreate_u64(rs1), vcreate_u64(rs2)));
                let zero = vdupq_n_u8(0);
                let result = vaeseq_u8(state, zero);
                vgetq_lane_u64::<0>(vreinterpretq_u64_u8(result))
            }

            /// `vaesmcq_u8(vaeseq_u8(state, zero))` maps exactly to `aes64esm`
            /// (MixColumns applied on top of the SubBytes + ShiftRows core).
            #[inline]
            #[target_feature(enable = "aes")]
            pub(super) fn aes64esm(rs1: u64, rs2: u64) -> u64 {
                let state = vreinterpretq_u8_u64(vcombine_u64(vcreate_u64(rs1), vcreate_u64(rs2)));
                let zero = vdupq_n_u8(0);
                let after_sub_shift = vaeseq_u8(state, zero);
                let result = vaesmcq_u8(after_sub_shift);
                vgetq_lane_u64::<0>(vreinterpretq_u64_u8(result))
            }
        }
    }
75    _ => {
76        /// Software fallback for aes64es, aes64esm
77        mod soft {
78            use crate::rv32::zk::zkn::zknd::rv32_zknd_helpers::{SBOX, gmul};
79
80            #[inline(always)]
81            fn mix_col(col: u32) -> u32 {
82                let s0 = col as u8;
83                let s1 = (col >> 8) as u8;
84                let s2 = (col >> 16) as u8;
85                let s3 = (col >> 24) as u8;
86                let r0 = gmul(s0, 0x02) ^ gmul(s1, 0x03) ^ s2 ^ s3;
87                let r1 = s0 ^ gmul(s1, 0x02) ^ gmul(s2, 0x03) ^ s3;
88                let r2 = s0 ^ s1 ^ gmul(s2, 0x02) ^ gmul(s3, 0x03);
89                let r3 = gmul(s0, 0x03) ^ s1 ^ s2 ^ gmul(s3, 0x02);
90                (r0 as u32) | ((r1 as u32) << 8) | ((r2 as u32) << 16) | ((r3 as u32) << 24)
91            }
92
93            /// Apply ShiftRows + SubBytes to the full 128-bit state `(rs1, rs2)` and return the
94            /// low 64-bit half of the result.
95            ///
96            /// State layout: column-major, little-endian 64-bit halves.
97            /// ShiftRows shifts row `r` left by `r` columns (cyclically over 4).
98            /// Output low half contains post-transform columns 0 and 1.
99            #[inline(always)]
100            pub(super) fn aes64es(rs1: u64, rs2: u64) -> u64 {
101                let state_byte = |col: usize, row: usize| -> u8 {
102                    let word = if col < 2 { rs1 } else { rs2 };
103                    (word >> ((col % 2) * 32 + row * 8)) as u8
104                };
105
106                let mut out = 0;
107                for c in 0..2usize {
108                    for r in 0..4usize {
109                        let src_col = (c + r) & 3;
110                        let b = SBOX[state_byte(src_col, r) as usize];
111                        out |= (b as u64) << (c * 32 + r * 8);
112                    }
113                }
114                out
115            }
116
117            #[inline(always)]
118            pub(super) fn aes64esm(rs1: u64, rs2: u64) -> u64 {
119                let lo = aes64es(rs1, rs2);
120                let col0 = mix_col(lo as u32);
121                let col1 = mix_col((lo >> 32) as u32);
122                (col0 as u64) | ((col1 as u64) << 32)
123            }
124        }
125    }
126}
127
/// RV64 Zkne `aes64es`: applies the forward AES ShiftRows + SubBytes steps
/// (no AddRoundKey, no MixColumns) to the 128-bit state held in `(rs1, rs2)`
/// — `rs1` is the low 64 bits, `rs2` the high 64 bits — and returns the low
/// 64 bits of the transformed state.
#[inline(always)]
#[doc(hidden)]
pub fn aes64es(rs1: u64, rs2: u64) -> u64 {
    // TODO: Miri is excluded because corresponding intrinsic is not implemented there
    cfg_select! {
        all(
            not(miri),
            target_arch = "riscv64",
            target_feature = "zkne"
        ) => {
            // SAFETY: Compile-time checked for supported feature — this arm
            // only compiles when `zkne` is statically enabled.
            unsafe {
                core::arch::riscv64::aes64es(rs1, rs2)
            }
        }
        all(target_arch = "x86_64", target_feature = "aes", target_feature = "sse4.1") => {
            // SAFETY: Compile-time checked for supported feature — this arm
            // only compiles when `aes` and `sse4.1` are statically enabled.
            unsafe {
                x86_64::aes64es(rs1, rs2)
            }
        }
        all(target_arch = "aarch64", target_feature = "aes") => {
            // SAFETY: Compile-time checked for supported feature — this arm
            // only compiles when `aes` is statically enabled.
            unsafe {
                aarch64::aes64es(rs1, rs2)
            }
        }
        // Portable fallback: plain safe Rust, no feature requirements.
        _ => { soft::aes64es(rs1, rs2) }
    }
}
158
/// RV64 Zkne `aes64esm`: applies the forward AES ShiftRows + SubBytes +
/// MixColumns steps (a middle encryption round minus AddRoundKey) to the
/// 128-bit state held in `(rs1, rs2)` — `rs1` is the low 64 bits, `rs2` the
/// high 64 bits — and returns the low 64 bits of the transformed state.
#[inline(always)]
#[doc(hidden)]
pub fn aes64esm(rs1: u64, rs2: u64) -> u64 {
    // TODO: Miri is excluded because corresponding intrinsic is not implemented there
    cfg_select! {
        all(
            not(miri),
            target_arch = "riscv64",
            target_feature = "zkne"
        ) => {
            // SAFETY: Compile-time checked for supported feature — this arm
            // only compiles when `zkne` is statically enabled.
            unsafe {
                core::arch::riscv64::aes64esm(rs1, rs2)
            }
        }
        all(target_arch = "x86_64", target_feature = "aes", target_feature = "sse4.1") => {
            // SAFETY: Compile-time checked for supported feature — this arm
            // only compiles when `aes` and `sse4.1` are statically enabled.
            unsafe {
                x86_64::aes64esm(rs1, rs2)
            }
        }
        all(target_arch = "aarch64", target_feature = "aes") => {
            // SAFETY: Compile-time checked for supported feature — this arm
            // only compiles when `aes` is statically enabled.
            unsafe {
                aarch64::aes64esm(rs1, rs2)
            }
        }
        // Portable fallback: plain safe Rust, no feature requirements.
        _ => { soft::aes64esm(rs1, rs2) }
    }
}