Skip to main content

ab_riscv_primitives/instructions/rv64/zk/zkn/
zknd.rs

1//! RV64 Zknd extension
2
3#[cfg(test)]
4mod tests;
5
6use crate::instructions::Instruction;
7use crate::registers::general_purpose::Register;
8use ab_riscv_macros::instruction;
9use core::{fmt, mem};
10
/// Round number (`rnum`) operand of the AES key schedule step-1 instruction (`aes64ks1i`).
///
/// `R0`..=`R9` each select one of the ten AES round constants, while `Final` (`0xA`) selects
/// no round constant at all (see [`Rv64ZkndKsRnum::constant`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Rv64ZkndKsRnum {
    /// Round number 0
    R0 = 0,
    /// Round number 1
    R1 = 1,
    /// Round number 2
    R2 = 2,
    /// Round number 3
    R3 = 3,
    /// Round number 4
    R4 = 4,
    /// Round number 5
    R5 = 5,
    /// Round number 6
    R6 = 6,
    /// Round number 7
    R7 = 7,
    /// Round number 8
    R8 = 8,
    /// Round number 9
    R9 = 9,
    /// Round number 10: the final step, which has no associated round constant
    Final = 10,
}
27
28impl const From<Rv64ZkndKsRnum> for u8 {
29    #[inline(always)]
30    fn from(rnum: Rv64ZkndKsRnum) -> Self {
31        rnum as u8
32    }
33}
34
35impl const From<Rv64ZkndKsRnum> for usize {
36    #[inline(always)]
37    fn from(rnum: Rv64ZkndKsRnum) -> Self {
38        usize::from(rnum as u8)
39    }
40}
41
42impl fmt::Display for Rv64ZkndKsRnum {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        fmt::Display::fmt(&(*self as u8), f)
45    }
46}
47
48impl Rv64ZkndKsRnum {
49    /// Round constants `RC[0..=9]`, indexed by rnum (0-based).
50    /// `RC[rnum]` corresponds to FIPS 197 `Rcon[rnum+1]`.
51    const RCON: [u8; 10] = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36];
52
53    /// Create from raw bits
54    #[inline(always)]
55    pub const fn from_bits(bits: u8) -> Option<Self> {
56        if bits <= Rv64ZkndKsRnum::Final as u8 {
57            // SAFETY: The transmute is safe because `Rv64ZkndKsRnum` is `#[repr(u8)]` enum with
58            // known valid values
59            Some(unsafe { mem::transmute::<u8, Self>(bits) })
60        } else {
61            None
62        }
63    }
64
65    /// Round constant (unless final)
66    #[inline(always)]
67    pub const fn constant(self) -> Option<u8> {
68        if matches!(self, Rv64ZkndKsRnum::Final) {
69            None
70        } else {
71            Some(Self::RCON[usize::from(self)])
72        }
73    }
74}
75
76/// RISC-V RV64 Zknd instructions (AES decryption and key schedule)
77#[instruction]
78#[derive(Debug, Clone, Copy, PartialEq, Eq)]
79pub enum Rv64ZkndInstruction<Reg> {
80    /// AES final round decryption: InvShiftRows + InvSubBytes, no MixColumns
81    Aes64Ds { rd: Reg, rs1: Reg, rs2: Reg },
82    /// AES middle round decryption: InvShiftRows + InvSubBytes + InvMixColumns
83    Aes64Dsm { rd: Reg, rs1: Reg, rs2: Reg },
84    /// AES inverse MixColumns on each 32-bit word of rs1
85    Aes64Im { rd: Reg, rs1: Reg },
86    /// AES key schedule step 1 (rnum in 0..=10)
87    Aes64Ks1i {
88        rd: Reg,
89        rs1: Reg,
90        rnum: Rv64ZkndKsRnum,
91    },
92    /// AES key schedule step 2
93    Aes64Ks2 { rd: Reg, rs1: Reg, rs2: Reg },
94}
95
96#[instruction]
97impl<Reg> const Instruction for Rv64ZkndInstruction<Reg>
98where
99    Reg: [const] Register<Type = u64>,
100{
101    type Reg = Reg;
102
103    #[inline(always)]
104    fn try_decode(instruction: u32) -> Option<Self> {
105        let opcode = (instruction & 0b111_1111) as u8;
106        let rd_bits = ((instruction >> 7) & 0x1f) as u8;
107        let funct3 = ((instruction >> 12) & 0b111) as u8;
108        let rs1_bits = ((instruction >> 15) & 0x1f) as u8;
109        let rs2_bits = ((instruction >> 20) & 0x1f) as u8;
110        let funct7 = ((instruction >> 25) & 0b111_1111) as u8;
111
112        match opcode {
113            // R-type: OP opcode (0x33)
114            //   aes64ds:  funct7=0b0011101, funct3=0 -> MATCH=0x3a000033
115            //   aes64dsm: funct7=0b0011111, funct3=0 -> MATCH=0x3e000033
116            //   aes64ks2: funct7=0b0111111, funct3=0 -> MATCH=0x7e000033
117            0b0110011 => {
118                if funct3 != 0b000 {
119                    None?;
120                }
121                let rd = Reg::from_bits(rd_bits)?;
122                let rs1 = Reg::from_bits(rs1_bits)?;
123                let rs2 = Reg::from_bits(rs2_bits)?;
124                match funct7 {
125                    0b0011101 => Some(Self::Aes64Ds { rd, rs1, rs2 }),
126                    0b0011111 => Some(Self::Aes64Dsm { rd, rs1, rs2 }),
127                    0b0111111 => Some(Self::Aes64Ks2 { rd, rs1, rs2 }),
128                    _ => None,
129                }
130            }
131            // I-type: OP-IMM opcode (0x13), funct3=0b001
132            //   aes64im:   imm[11:0]=0x300  (funct7=0b0011000, rs2=0b00000) -> MATCH=0x30001013
133            //   aes64ks1i: imm[11:5]=0b0011000, imm[4]=1, imm[3:0]=rnum     -> MATCH=0x31001013+
134            0b0010011 => {
135                if funct3 != 0b001 {
136                    None?;
137                }
138                let rd = Reg::from_bits(rd_bits)?;
139                let rs1 = Reg::from_bits(rs1_bits)?;
140                let imm12 = instruction >> 20;
141                if imm12 == 0x300 {
142                    Some(Self::Aes64Im { rd, rs1 })
143                } else if (imm12 >> 5) == 0b0011000 && (imm12 & 0b1_0000) != 0 {
144                    // bits[11:5]=0b0011000, bit[4]=1, bits[3:0]=rnum
145                    let rnum = Rv64ZkndKsRnum::from_bits((imm12 & 0xf) as u8)?;
146                    Some(Self::Aes64Ks1i { rd, rs1, rnum })
147                } else {
148                    None
149                }
150            }
151            _ => None,
152        }
153    }
154
155    #[inline(always)]
156    fn alignment() -> u8 {
157        align_of::<u32>() as u8
158    }
159
160    #[inline(always)]
161    fn size(&self) -> u8 {
162        size_of::<u32>() as u8
163    }
164}
165
166impl<Reg> fmt::Display for Rv64ZkndInstruction<Reg>
167where
168    Reg: fmt::Display,
169{
170    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171        match self {
172            Self::Aes64Ds { rd, rs1, rs2 } => write!(f, "aes64ds {rd}, {rs1}, {rs2}"),
173            Self::Aes64Dsm { rd, rs1, rs2 } => write!(f, "aes64dsm {rd}, {rs1}, {rs2}"),
174            Self::Aes64Im { rd, rs1 } => write!(f, "aes64im {rd}, {rs1}"),
175            Self::Aes64Ks1i { rd, rs1, rnum } => write!(f, "aes64ks1i {rd}, {rs1}, {rnum}"),
176            Self::Aes64Ks2 { rd, rs1, rs2 } => write!(f, "aes64ks2 {rd}, {rs1}, {rs2}"),
177        }
178    }
179}