ab_riscv_interpreter/rv32/zk/zkn/zkne/rv32_zkne_helpers.rs
1//! Opaque helpers for RV32 Zkne extension
2
3use ab_riscv_primitives::prelude::*;
4
5/// Software fallback for aes32esi and aes32esmi.
6///
7/// Both instructions share the same S-box and MixColumn machinery; the only difference is whether
8/// forward MixColumns is applied.
9#[cfg(not(all(not(miri), target_arch = "riscv32", target_feature = "zkne")))]
10pub(in super::super) mod soft {
11 use crate::rv32::zk::zkn::zknd::rv32_zknd_helpers::{SBOX, gmul};
12
13 /// Compute the partial forward MixColumns contribution for a single substituted byte `b`.
14 ///
15 /// This is `aes_mixcolumn_byte_fwd` from the Sail reference:
16 /// the four output bytes of MixColumns when the input column has `b` in one position and zeros
17 /// elsewhere - packed into a little-endian `u32`.
18 ///
19 /// Column matrix multiply for MixColumns:
20 /// ```text
21 /// r0 = 0x02*b
22 /// r1 = 0x01*b
23 /// r2 = 0x01*b
24 /// r3 = 0x03*b
25 /// ```
26 #[inline(always)]
27 pub(super) fn mix_col_byte(b: u8) -> u32 {
28 let r0 = u32::from(gmul(b, 0x02));
29 let r1 = u32::from(b);
30 let r2 = u32::from(b);
31 let r3 = u32::from(gmul(b, 0x03));
32 r0 | (r1 << 8) | (r2 << 16) | (r3 << 24)
33 }
34
35 /// `aes32esi rs1, rs2, bs`
36 ///
37 /// Pseudocode:
38 /// ```text
39 /// shamt = bs * 8
40 /// si = (rs2 >> shamt) & 0xff
41 /// so = SBOX[si] as u32
42 /// rd = rs1 ^ rol32(so, shamt)
43 /// ```
44 #[inline(always)]
45 pub(super) fn aes32esi(rs1: u32, rs2: u32, bs: u8) -> u32 {
46 let shamt = u32::from(bs) * 8;
47 let si = ((rs2 >> shamt) & 0xff) as u8;
48 let so = u32::from(SBOX[usize::from(si)]);
49 rs1 ^ so.rotate_left(shamt)
50 }
51
52 /// `aes32esmi rs1, rs2, bs`
53 ///
54 /// Pseudocode:
55 /// ```text
56 /// shamt = bs * 8
57 /// si = (rs2 >> shamt) & 0xff
58 /// so = SBOX[si]
59 /// mixed = mix_col_byte(so)
60 /// rd = rs1 ^ rol32(mixed, shamt)
61 /// ```
62 #[inline(always)]
63 pub(super) fn aes32esmi(rs1: u32, rs2: u32, bs: u8) -> u32 {
64 let shamt = u32::from(bs) * 8;
65 let si = ((rs2 >> shamt) & 0xff) as u8;
66 let so = SBOX[usize::from(si)];
67 let mixed = mix_col_byte(so);
68 rs1 ^ mixed.rotate_left(shamt)
69 }
70}
71
72#[inline(always)]
73#[doc(hidden)]
74pub fn aes32esi(rs1: u32, rs2: u32, bs: Rv32AesBs) -> u32 {
75 // TODO: Miri is excluded because corresponding intrinsic is not implemented there
76 cfg_select! {
77 all(
78 not(miri),
79 target_arch = "riscv32",
80 target_feature = "zkne"
81 ) => {
82 // SAFETY: Compile-time checked for supported feature
83 unsafe {
84 core::arch::riscv32::aes32esi(rs1, rs2, u8::from(bs))
85 }
86 }
87 _ => { soft::aes32esi(rs1, rs2, u8::from(bs)) }
88 }
89}
90
91#[inline(always)]
92#[doc(hidden)]
93pub fn aes32esmi(rs1: u32, rs2: u32, bs: Rv32AesBs) -> u32 {
94 // TODO: Miri is excluded because corresponding intrinsic is not implemented there
95 cfg_select! {
96 all(
97 not(miri),
98 target_arch = "riscv32",
99 target_feature = "zkne"
100 ) => {
101 // SAFETY: Compile-time checked for supported feature
102 unsafe {
103 core::arch::riscv32::aes32esmi(rs1, rs2, u8::from(bs))
104 }
105 }
106 _ => { soft::aes32esmi(rs1, rs2, u8::from(bs)) }
107 }
108}