ab_riscv_interpreter/rv64/zk/zkn/zknd/
rv64_zknd_helpers.rs1use ab_riscv_primitives::prelude::*;
4
5#[cfg(not(all(not(miri), target_arch = "riscv64", target_feature = "zknd")))]
9mod ks {
10 use crate::rv32::zk::zkn::zknd::rv32_zknd_helpers::SBOX;
11 use ab_riscv_primitives::prelude::*;
12
13 #[inline(always)]
24 pub(super) fn aes64ks1i(rs1: u64, rnum: Rv64ZkndKsRnum) -> u64 {
25 let w = (rs1 >> 32) as u32;
26
27 let rotated = if rnum != Rv64ZkndKsRnum::Final {
28 w.rotate_right(8)
29 } else {
30 w
31 };
32
33 let b0 = u32::from(SBOX[(rotated & 0xff) as usize]);
34 let b1 = u32::from(SBOX[((rotated >> 8) & 0xff) as usize]);
35 let b2 = u32::from(SBOX[((rotated >> 16) & 0xff) as usize]);
36 let b3 = u32::from(SBOX[((rotated >> 24) & 0xff) as usize]);
37 let subbed = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24);
38
39 let result = if let Some(round_constant) = rnum.constant() {
40 subbed ^ u32::from(round_constant)
41 } else {
42 subbed
43 };
44
45 u64::from(result) | (u64::from(result) << 32)
46 }
47
48 #[inline(always)]
57 pub(super) fn aes64ks2(rs1: u64, rs2: u64) -> u64 {
58 let w0 = (rs1 >> 32) as u32 ^ rs2 as u32;
59 let w1 = w0 ^ (rs2 >> 32) as u32;
60 u64::from(w0) | (u64::from(w1) << 32)
61 }
62}
63
64cfg_select! {
65 all(
66 not(miri),
67 target_arch = "riscv64",
68 target_feature = "zknd"
69 ) => {
70 }
72 all(target_arch = "x86_64", target_feature = "aes", target_feature = "sse4.1") => {
73 mod x86_64 {
75 use core::arch::x86_64::{
76 _mm_aesdec_si128, _mm_aesdeclast_si128, _mm_aesimc_si128, _mm_extract_epi64,
77 _mm_set_epi64x, _mm_setzero_si128,
78 };
79
80 #[inline]
83 #[target_feature(enable = "aes,sse4.1")]
84 pub(super) fn aes64ds(rs1: u64, rs2: u64) -> u64 {
85 let state = _mm_set_epi64x(rs2.cast_signed(), rs1.cast_signed());
86 let zero = _mm_setzero_si128();
87 let result = _mm_aesdeclast_si128(state, zero);
88 _mm_extract_epi64::<0>(result).cast_unsigned()
89 }
90
91 #[inline]
94 #[target_feature(enable = "aes,sse4.1")]
95 pub(super) fn aes64dsm(rs1: u64, rs2: u64) -> u64 {
96 let state = _mm_set_epi64x(rs2.cast_signed(), rs1.cast_signed());
97 let zero = _mm_setzero_si128();
98 let result = _mm_aesdec_si128(state, zero);
99 _mm_extract_epi64::<0>(result).cast_unsigned()
100 }
101
102 #[inline]
105 #[target_feature(enable = "aes,sse4.1")]
106 pub(super) fn aes64im(rs1: u64) -> u64 {
107 let state = _mm_set_epi64x(rs1.cast_signed(), rs1.cast_signed());
108 let result = _mm_aesimc_si128(state);
109 _mm_extract_epi64::<0>(result).cast_unsigned()
110 }
111 }
112 }
113 all(target_arch = "aarch64", target_feature = "aes") => {
114 mod aarch64 {
116 use core::arch::aarch64::{
117 vaesdq_u8, vaesimcq_u8, vcombine_u64, vcreate_u64, vdupq_n_u8, vgetq_lane_u64,
118 vreinterpretq_u8_u64, vreinterpretq_u64_u8,
119 };
120
121 #[inline]
125 #[target_feature(enable = "aes")]
126 pub(super) fn aes64ds(rs1: u64, rs2: u64) -> u64 {
127 let state = vreinterpretq_u8_u64(vcombine_u64(vcreate_u64(rs1), vcreate_u64(rs2)));
128 let zero = vdupq_n_u8(0);
129 let result = vaesdq_u8(state, zero);
130 vgetq_lane_u64::<0>(vreinterpretq_u64_u8(result))
131 }
132
133 #[inline]
135 #[target_feature(enable = "aes")]
136 pub(super) fn aes64dsm(rs1: u64, rs2: u64) -> u64 {
137 let state = vreinterpretq_u8_u64(vcombine_u64(vcreate_u64(rs1), vcreate_u64(rs2)));
138 let zero = vdupq_n_u8(0);
139 let after_sub_shift = vaesdq_u8(state, zero);
140 let result = vaesimcq_u8(after_sub_shift);
141 vgetq_lane_u64::<0>(vreinterpretq_u64_u8(result))
142 }
143
144 #[inline]
145 #[target_feature(enable = "aes")]
146 pub(super) fn aes64im(rs1: u64) -> u64 {
147 let state = vreinterpretq_u8_u64(vcombine_u64(vcreate_u64(rs1), vcreate_u64(rs1)));
148 let result = vaesimcq_u8(state);
149 vgetq_lane_u64::<0>(vreinterpretq_u64_u8(result))
150 }
151 }
152 }
153 _ => {
154 mod soft {
156 use crate::rv32::zk::zkn::zknd::rv32_zknd_helpers::{INV_SBOX, gmul};
157
158 #[inline(always)]
159 fn inv_mix_col(col: u32) -> u32 {
160 let s0 = col as u8;
161 let s1 = (col >> 8) as u8;
162 let s2 = (col >> 16) as u8;
163 let s3 = (col >> 24) as u8;
164 let r0 = gmul(s0, 0x0e) ^ gmul(s1, 0x0b) ^ gmul(s2, 0x0d) ^ gmul(s3, 0x09);
165 let r1 = gmul(s0, 0x09) ^ gmul(s1, 0x0e) ^ gmul(s2, 0x0b) ^ gmul(s3, 0x0d);
166 let r2 = gmul(s0, 0x0d) ^ gmul(s1, 0x09) ^ gmul(s2, 0x0e) ^ gmul(s3, 0x0b);
167 let r3 = gmul(s0, 0x0b) ^ gmul(s1, 0x0d) ^ gmul(s2, 0x09) ^ gmul(s3, 0x0e);
168 (r0 as u32) | ((r1 as u32) << 8) | ((r2 as u32) << 16) | ((r3 as u32) << 24)
169 }
170
171 #[inline(always)]
181 pub(super) fn aes64ds(rs1: u64, rs2: u64) -> u64 {
182 let state_byte = |col: usize, row: usize| -> u8 {
183 let word = if col < 2 { rs1 } else { rs2 };
184 (word >> ((col % 2) * 32 + row * 8)) as u8
185 };
186
187 let mut out = 0;
188 for c in 0..2usize {
189 for r in 0..4usize {
190 let src_col = (c + 4 - r) & 3;
191 let b = INV_SBOX[state_byte(src_col, r) as usize];
192 out |= (b as u64) << (c * 32 + r * 8);
193 }
194 }
195 out
196 }
197
198 #[inline(always)]
199 pub(super) fn aes64dsm(rs1: u64, rs2: u64) -> u64 {
200 let lo = aes64ds(rs1, rs2);
201 let col0 = inv_mix_col(lo as u32);
202 let col1 = inv_mix_col((lo >> 32) as u32);
203 (col0 as u64) | ((col1 as u64) << 32)
204 }
205
206 #[inline(always)]
207 pub(super) fn aes64im(rs1: u64) -> u64 {
208 let col0 = inv_mix_col(rs1 as u32);
209 let col1 = inv_mix_col((rs1 >> 32) as u32);
210 (col0 as u64) | ((col1 as u64) << 32)
211 }
212 }
213 }
214}
215
216#[inline(always)]
217#[doc(hidden)]
218pub fn aes64ds(rs1: u64, rs2: u64) -> u64 {
219 cfg_select! {
221 all(
222 not(miri),
223 target_arch = "riscv64",
224 target_feature = "zknd"
225 ) => {
226 unsafe {
228 core::arch::riscv64::aes64ds(rs1, rs2)
229 }
230 }
231 all(target_arch = "x86_64", target_feature = "aes", target_feature = "sse4.1") => {
232 unsafe {
234 x86_64::aes64ds(rs1, rs2)
235 }
236 }
237 all(target_arch = "aarch64", target_feature = "aes") => {
238 unsafe {
240 aarch64::aes64ds(rs1, rs2)
241 }
242 }
243 _ => { soft::aes64ds(rs1, rs2) }
244 }
245}
246
247#[inline(always)]
248#[doc(hidden)]
249pub fn aes64dsm(rs1: u64, rs2: u64) -> u64 {
250 cfg_select! {
252 all(
253 not(miri),
254 target_arch = "riscv64",
255 target_feature = "zknd"
256 ) => {
257 unsafe {
259 core::arch::riscv64::aes64dsm(rs1, rs2)
260 }
261 }
262 all(target_arch = "x86_64", target_feature = "aes", target_feature = "sse4.1") => {
263 unsafe {
265 x86_64::aes64dsm(rs1, rs2)
266 }
267 }
268 all(target_arch = "aarch64", target_feature = "aes") => {
269 unsafe {
271 aarch64::aes64dsm(rs1, rs2)
272 }
273 }
274 _ => { soft::aes64dsm(rs1, rs2) }
275 }
276}
277
278#[inline(always)]
279#[doc(hidden)]
280pub fn aes64im(rs1: u64) -> u64 {
281 cfg_select! {
283 all(
284 not(miri),
285 target_arch = "riscv64",
286 target_feature = "zknd"
287 ) => {
288 unsafe {
290 core::arch::riscv64::aes64im(rs1)
291 }
292 }
293 all(target_arch = "x86_64", target_feature = "aes", target_feature = "sse4.1") => {
294 unsafe {
296 x86_64::aes64im(rs1)
297 }
298 }
299 all(target_arch = "aarch64", target_feature = "aes") => {
300 unsafe {
302 aarch64::aes64im(rs1)
303 }
304 }
305 _ => { soft::aes64im(rs1) }
306 }
307}
308
309#[inline(always)]
310#[doc(hidden)]
311pub fn aes64ks1i(rs1: u64, rnum: Rv64ZkndKsRnum) -> u64 {
312 cfg_select! {
314 all(
315 not(miri),
316 target_arch = "riscv64",
317 target_feature = "zknd"
318 ) => {
319 unsafe {
321 core::arch::riscv64::aes64ks1i(rs1, rnum as u8)
322 }
323 }
324 _ => { ks::aes64ks1i(rs1, rnum) }
325 }
326}
327
328#[inline(always)]
329#[doc(hidden)]
330pub fn aes64ks2(rs1: u64, rs2: u64) -> u64 {
331 cfg_select! {
333 all(
334 not(miri),
335 target_arch = "riscv64",
336 target_feature = "zknd"
337 ) => {
338 unsafe {
340 core::arch::riscv64::aes64ks2(rs1, rs2)
341 }
342 }
343 _ => { ks::aes64ks2(rs1, rs2) }
344 }
345}