Skip to main content

ab_riscv_interpreter/rv32/zk/zkn/zkne/
rv32_zkne_helpers.rs

1//! Opaque helpers for RV32 Zkne extension
2
3use ab_riscv_primitives::prelude::*;
4
5/// Software fallback for aes32esi and aes32esmi.
6///
7/// Both instructions share the same S-box and MixColumn machinery; the only difference is whether
8/// forward MixColumns is applied.
9#[cfg(not(all(not(miri), target_arch = "riscv32", target_feature = "zkne")))]
10pub(in super::super) mod soft {
11    use crate::rv32::zk::zkn::zknd::rv32_zknd_helpers::{SBOX, gmul};
12
13    /// Compute the partial forward MixColumns contribution for a single substituted byte `b`.
14    ///
15    /// This is `aes_mixcolumn_byte_fwd` from the Sail reference:
16    /// the four output bytes of MixColumns when the input column has `b` in one position and zeros
17    /// elsewhere - packed into a little-endian `u32`.
18    ///
19    /// Column matrix multiply for MixColumns:
20    /// ```text
21    /// r0 = 0x02*b
22    /// r1 = 0x01*b
23    /// r2 = 0x01*b
24    /// r3 = 0x03*b
25    /// ```
26    #[inline(always)]
27    pub(super) fn mix_col_byte(b: u8) -> u32 {
28        let r0 = u32::from(gmul(b, 0x02));
29        let r1 = u32::from(b);
30        let r2 = u32::from(b);
31        let r3 = u32::from(gmul(b, 0x03));
32        r0 | (r1 << 8) | (r2 << 16) | (r3 << 24)
33    }
34
35    /// `aes32esi rs1, rs2, bs`
36    ///
37    /// Pseudocode:
38    /// ```text
39    /// shamt = bs * 8
40    /// si    = (rs2 >> shamt) & 0xff
41    /// so    = SBOX[si] as u32
42    /// rd    = rs1 ^ rol32(so, shamt)
43    /// ```
44    #[inline(always)]
45    pub(super) fn aes32esi(rs1: u32, rs2: u32, bs: u8) -> u32 {
46        let shamt = u32::from(bs) * 8;
47        let si = ((rs2 >> shamt) & 0xff) as u8;
48        let so = u32::from(SBOX[usize::from(si)]);
49        rs1 ^ so.rotate_left(shamt)
50    }
51
52    /// `aes32esmi rs1, rs2, bs`
53    ///
54    /// Pseudocode:
55    /// ```text
56    /// shamt = bs * 8
57    /// si    = (rs2 >> shamt) & 0xff
58    /// so    = SBOX[si]
59    /// mixed = mix_col_byte(so)
60    /// rd    = rs1 ^ rol32(mixed, shamt)
61    /// ```
62    #[inline(always)]
63    pub(super) fn aes32esmi(rs1: u32, rs2: u32, bs: u8) -> u32 {
64        let shamt = u32::from(bs) * 8;
65        let si = ((rs2 >> shamt) & 0xff) as u8;
66        let so = SBOX[usize::from(si)];
67        let mixed = mix_col_byte(so);
68        rs1 ^ mixed.rotate_left(shamt)
69    }
70}
71
72#[inline(always)]
73#[doc(hidden)]
74pub fn aes32esi(rs1: u32, rs2: u32, bs: Rv32AesBs) -> u32 {
75    // TODO: Miri is excluded because corresponding intrinsic is not implemented there
76    cfg_select! {
77        all(
78            not(miri),
79            target_arch = "riscv32",
80            target_feature = "zkne"
81        ) => {
82            // SAFETY: Compile-time checked for supported feature
83            unsafe {
84                core::arch::riscv32::aes32esi(rs1, rs2, u8::from(bs))
85            }
86        }
87        _ => { soft::aes32esi(rs1, rs2, u8::from(bs)) }
88    }
89}
90
91#[inline(always)]
92#[doc(hidden)]
93pub fn aes32esmi(rs1: u32, rs2: u32, bs: Rv32AesBs) -> u32 {
94    // TODO: Miri is excluded because corresponding intrinsic is not implemented there
95    cfg_select! {
96        all(
97            not(miri),
98            target_arch = "riscv32",
99            target_feature = "zkne"
100        ) => {
101            // SAFETY: Compile-time checked for supported feature
102            unsafe {
103                core::arch::riscv32::aes32esmi(rs1, rs2, u8::from(bs))
104            }
105        }
106        _ => { soft::aes32esmi(rs1, rs2, u8::from(bs)) }
107    }
108}