Skip to main content

ab_riscv_primitives/instructions/rv32/zk/zkn/
zknd.rs

1//! RV32 Zknd extension
2
3#[cfg(test)]
4mod tests;
5
6use crate::instructions::Instruction;
7use crate::registers::general_purpose::Register;
8use ab_riscv_macros::instruction;
9use core::fmt;
10
11/// 2-bit byte-select immediate for RV32 AES instructions.
12///
13/// Selects which byte of `rs2` is fed into the S-box: `bs ∈ {0,1,2,3}`.
14#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15#[repr(u8)]
16pub enum Rv32AesBs {
17    B0 = 0,
18    B1 = 1,
19    B2 = 2,
20    B3 = 3,
21}
22
23impl From<Rv32AesBs> for u8 {
24    #[inline(always)]
25    fn from(bs: Rv32AesBs) -> Self {
26        bs as u8
27    }
28}
29
30impl From<Rv32AesBs> for usize {
31    #[inline(always)]
32    fn from(bs: Rv32AesBs) -> Self {
33        usize::from(bs as u8)
34    }
35}
36
37impl fmt::Display for Rv32AesBs {
38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39        fmt::Display::fmt(&(*self as u8), f)
40    }
41}
42
43impl Rv32AesBs {
44    /// Create from raw 2-bit value. Returns `None` if `bits > 3`.
45    #[inline(always)]
46    pub const fn from_bits(bits: u8) -> Option<Self> {
47        match bits {
48            0 => Some(Self::B0),
49            1 => Some(Self::B1),
50            2 => Some(Self::B2),
51            3 => Some(Self::B3),
52            _ => None,
53        }
54    }
55}
56
57/// RISC-V RV32 Zknd instructions (AES decryption)
58#[instruction]
59#[derive(Debug, Clone, Copy, PartialEq, Eq)]
60pub enum Rv32ZkndInstruction<Reg> {
61    /// AES final round decryption step: InvSubBytes on one byte of rs2,
62    /// rotated to the byte lane selected by bs, XOR'd into rs1.
63    ///
64    /// `rd = rs1 ^ rol32(INV_SBOX[(rs2 >> (bs*8)) & 0xff] as u32, bs*8)`
65    Aes32Dsi {
66        rd: Reg,
67        rs1: Reg,
68        rs2: Reg,
69        bs: Rv32AesBs,
70    },
71    /// AES middle round decryption step: InvSubBytes + partial InvMixColumns
72    /// on one byte of rs2, rotated to the byte lane selected by bs, XOR'd into rs1.
73    ///
74    /// `rd = rs1 ^ rol32(InvMixColByte(INV_SBOX[(rs2 >> (bs*8)) & 0xff]), bs*8)`
75    Aes32Dsmi {
76        rd: Reg,
77        rs1: Reg,
78        rs2: Reg,
79        bs: Rv32AesBs,
80    },
81}
82
83/// Encoding layout (R-type, opcode 0x33, funct3 0x0):
84///
85/// ```text
86/// [31:30] bs       - 2-bit byte select
87/// [29:25] funct5   - 0b10101 (aes32dsi) / 0b10111 (aes32dsmi)
88/// [24:20] rs2
89/// [19:15] rs1
90/// [14:12] funct3   - 0b000
91/// [11:7]  rd
92/// [6:0]   opcode   - 0b0110011 (OP)
93/// ```
94///
95/// Ratified match/mask values (from riscv-opcodes):
96///   MATCH_AES32DSI  = 0x2a000033, MASK_AES32DSI  = 0x3e00707f
97///   MATCH_AES32DSMI = 0x2e000033, MASK_AES32DSMI = 0x3e00707f
98///
99/// `rd` and `rs1` are independent fields. The assembler convention places
100/// the accumulator in both rd and rs1 (the `rt` pattern), but the hardware
101/// does not require rd == rs1 and the decoder must not enforce it.
102#[instruction]
103impl<Reg> const Instruction for Rv32ZkndInstruction<Reg>
104where
105    Reg: [const] Register<Type = u32>,
106{
107    type Reg = Reg;
108
109    #[inline(always)]
110    fn try_decode(instruction: u32) -> Option<Self> {
111        let opcode = (instruction & 0b111_1111) as u8;
112        let rd_bits = ((instruction >> 7) & 0x1f) as u8;
113        let funct3 = ((instruction >> 12) & 0b111) as u8;
114        let rs1_bits = ((instruction >> 15) & 0x1f) as u8;
115        let rs2_bits = ((instruction >> 20) & 0x1f) as u8;
116        let funct5 = ((instruction >> 25) & 0b1_1111) as u8;
117        let bs_bits = ((instruction >> 30) & 0b11) as u8;
118
119        // R-type OP opcode only
120        if opcode != 0b0110011 {
121            None?;
122        }
123        if funct3 != 0b000 {
124            None?;
125        }
126
127        let rd = Reg::from_bits(rd_bits)?;
128        let rs1 = Reg::from_bits(rs1_bits)?;
129        let rs2 = Reg::from_bits(rs2_bits)?;
130        let bs = Rv32AesBs::from_bits(bs_bits)?;
131
132        match funct5 {
133            // aes32dsi:  bs[31:30] | 0b10101[29:25]
134            0b10101 => Some(Self::Aes32Dsi { rd, rs1, rs2, bs }),
135            // aes32dsmi: bs[31:30] | 0b10111[29:25]
136            0b10111 => Some(Self::Aes32Dsmi { rd, rs1, rs2, bs }),
137            _ => None,
138        }
139    }
140
141    #[inline(always)]
142    fn alignment() -> u8 {
143        align_of::<u32>() as u8
144    }
145
146    #[inline(always)]
147    fn size(&self) -> u8 {
148        size_of::<u32>() as u8
149    }
150}
151
152impl<Reg> fmt::Display for Rv32ZkndInstruction<Reg>
153where
154    Reg: fmt::Display,
155{
156    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157        match self {
158            Self::Aes32Dsi { rd, rs1, rs2, bs } => {
159                write!(f, "aes32dsi {rd}, {rs1}, {rs2}, {bs}")
160            }
161            Self::Aes32Dsmi { rd, rs1, rs2, bs } => {
162                write!(f, "aes32dsmi {rd}, {rs1}, {rs2}, {bs}")
163            }
164        }
165    }
166}