ab_riscv_interpreter/rv64/zk/zkn/zkne/
rv64_zkne_helpers.rs1cfg_select! {
4 all(
5 not(miri),
6 target_arch = "riscv64",
7 target_feature = "zkne"
8 ) => {
9 }
11 all(target_arch = "x86_64", target_feature = "aes", target_feature = "sse4.1") => {
12 mod x86_64 {
14 use core::arch::x86_64::{
15 _mm_aesenclast_si128, _mm_aesenc_si128, _mm_extract_epi64,
16 _mm_set_epi64x, _mm_setzero_si128,
17 };
18
19 #[inline]
22 #[target_feature(enable = "aes,sse4.1")]
23 pub(super) fn aes64es(rs1: u64, rs2: u64) -> u64 {
24 let state = _mm_set_epi64x(rs2.cast_signed(), rs1.cast_signed());
25 let zero = _mm_setzero_si128();
26 let result = _mm_aesenclast_si128(state, zero);
27 _mm_extract_epi64::<0>(result).cast_unsigned()
28 }
29
30 #[inline]
33 #[target_feature(enable = "aes,sse4.1")]
34 pub(super) fn aes64esm(rs1: u64, rs2: u64) -> u64 {
35 let state = _mm_set_epi64x(rs2.cast_signed(), rs1.cast_signed());
36 let zero = _mm_setzero_si128();
37 let result = _mm_aesenc_si128(state, zero);
38 _mm_extract_epi64::<0>(result).cast_unsigned()
39 }
40 }
41 }
42 all(target_arch = "aarch64", target_feature = "aes") => {
43 mod aarch64 {
49 use core::arch::aarch64::{
50 vaeseq_u8, vaesmcq_u8, vcombine_u64, vcreate_u64, vdupq_n_u8, vgetq_lane_u64,
51 vreinterpretq_u8_u64, vreinterpretq_u64_u8,
52 };
53
54 #[inline]
55 #[target_feature(enable = "aes")]
56 pub(super) fn aes64es(rs1: u64, rs2: u64) -> u64 {
57 let state = vreinterpretq_u8_u64(vcombine_u64(vcreate_u64(rs1), vcreate_u64(rs2)));
58 let zero = vdupq_n_u8(0);
59 let result = vaeseq_u8(state, zero);
60 vgetq_lane_u64::<0>(vreinterpretq_u64_u8(result))
61 }
62
63 #[inline]
65 #[target_feature(enable = "aes")]
66 pub(super) fn aes64esm(rs1: u64, rs2: u64) -> u64 {
67 let state = vreinterpretq_u8_u64(vcombine_u64(vcreate_u64(rs1), vcreate_u64(rs2)));
68 let zero = vdupq_n_u8(0);
69 let after_sub_shift = vaeseq_u8(state, zero);
70 let result = vaesmcq_u8(after_sub_shift);
71 vgetq_lane_u64::<0>(vreinterpretq_u64_u8(result))
72 }
73 }
74 }
75 _ => {
76 mod soft {
78 use crate::rv32::zk::zkn::zknd::rv32_zknd_helpers::{SBOX, gmul};
79
80 #[inline(always)]
81 fn mix_col(col: u32) -> u32 {
82 let s0 = col as u8;
83 let s1 = (col >> 8) as u8;
84 let s2 = (col >> 16) as u8;
85 let s3 = (col >> 24) as u8;
86 let r0 = gmul(s0, 0x02) ^ gmul(s1, 0x03) ^ s2 ^ s3;
87 let r1 = s0 ^ gmul(s1, 0x02) ^ gmul(s2, 0x03) ^ s3;
88 let r2 = s0 ^ s1 ^ gmul(s2, 0x02) ^ gmul(s3, 0x03);
89 let r3 = gmul(s0, 0x03) ^ s1 ^ s2 ^ gmul(s3, 0x02);
90 (r0 as u32) | ((r1 as u32) << 8) | ((r2 as u32) << 16) | ((r3 as u32) << 24)
91 }
92
93 #[inline(always)]
100 pub(super) fn aes64es(rs1: u64, rs2: u64) -> u64 {
101 let state_byte = |col: usize, row: usize| -> u8 {
102 let word = if col < 2 { rs1 } else { rs2 };
103 (word >> ((col % 2) * 32 + row * 8)) as u8
104 };
105
106 let mut out = 0;
107 for c in 0..2usize {
108 for r in 0..4usize {
109 let src_col = (c + r) & 3;
110 let b = SBOX[state_byte(src_col, r) as usize];
111 out |= (b as u64) << (c * 32 + r * 8);
112 }
113 }
114 out
115 }
116
117 #[inline(always)]
118 pub(super) fn aes64esm(rs1: u64, rs2: u64) -> u64 {
119 let lo = aes64es(rs1, rs2);
120 let col0 = mix_col(lo as u32);
121 let col1 = mix_col((lo >> 32) as u32);
122 (col0 as u64) | ((col1 as u64) << 32)
123 }
124 }
125 }
126}
127
128#[inline(always)]
129#[doc(hidden)]
130pub fn aes64es(rs1: u64, rs2: u64) -> u64 {
131 cfg_select! {
133 all(
134 not(miri),
135 target_arch = "riscv64",
136 target_feature = "zkne"
137 ) => {
138 unsafe {
140 core::arch::riscv64::aes64es(rs1, rs2)
141 }
142 }
143 all(target_arch = "x86_64", target_feature = "aes", target_feature = "sse4.1") => {
144 unsafe {
146 x86_64::aes64es(rs1, rs2)
147 }
148 }
149 all(target_arch = "aarch64", target_feature = "aes") => {
150 unsafe {
152 aarch64::aes64es(rs1, rs2)
153 }
154 }
155 _ => { soft::aes64es(rs1, rs2) }
156 }
157}
158
159#[inline(always)]
160#[doc(hidden)]
161pub fn aes64esm(rs1: u64, rs2: u64) -> u64 {
162 cfg_select! {
164 all(
165 not(miri),
166 target_arch = "riscv64",
167 target_feature = "zkne"
168 ) => {
169 unsafe {
171 core::arch::riscv64::aes64esm(rs1, rs2)
172 }
173 }
174 all(target_arch = "x86_64", target_feature = "aes", target_feature = "sse4.1") => {
175 unsafe {
177 x86_64::aes64esm(rs1, rs2)
178 }
179 }
180 all(target_arch = "aarch64", target_feature = "aes") => {
181 unsafe {
183 aarch64::aes64esm(rs1, rs2)
184 }
185 }
186 _ => { soft::aes64esm(rs1, rs2) }
187 }
188}