ab_riscv_interpreter/v/zve64x/arith/zve64x_arith_helpers.rs

1//! Opaque helpers for Zve64x extension
2
3use crate::v::vector_registers::VectorRegistersExt;
4use crate::v::zve64x::load::zve64x_load_helpers::{mask_bit, snapshot_mask};
5use crate::v::zve64x::zve64x_helpers::INSTRUCTION_SIZE;
6use crate::{ExecutionError, InterpreterState, ProgramCounter, VirtualMemory};
7use ab_riscv_primitives::instructions::v::Vsew;
8use ab_riscv_primitives::registers::general_purpose::Register;
9use ab_riscv_primitives::registers::vector::VReg;
10use core::fmt;
11
12/// Check that `vreg` (`vd`/`vs`) is aligned to `group_regs` and fits within `[0, 32)`
13#[inline(always)]
14#[doc(hidden)]
15pub fn check_vreg_group_alignment<Reg, ExtState, Memory, PC, IH, CustomError>(
16    state: &InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
17    vreg: VReg,
18    group_regs: u8,
19) -> Result<(), ExecutionError<Reg::Type, CustomError>>
20where
21    Reg: Register,
22    [(); Reg::N]:,
23    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
24{
25    let vd_idx = vreg.bits();
26    if !vd_idx.is_multiple_of(group_regs) || vd_idx + group_regs > 32 {
27        return Err(ExecutionError::IllegalInstruction {
28            address: state.instruction_fetcher.old_pc(INSTRUCTION_SIZE),
29        });
30    }
31    Ok(())
32}
33
34/// Read a SEW-wide element from register group `[base_reg, base_reg + group_regs)` as `u64`.
35///
36/// Element `elem_i` occupies bytes at:
37///   - register `base_reg + elem_i / elems_per_reg`
38///   - byte offset `(elem_i % elems_per_reg) * sew_bytes`
39///
40/// The value is zero-extended to `u64`.
41///
42/// # Safety
43/// `base_reg + elem_i / (VLENB / sew_bytes) < 32` must hold.
44#[inline(always)]
45pub(in super::super) unsafe fn read_element_u64<const VLENB: usize>(
46    vreg: &[[u8; VLENB]; 32],
47    base_reg: usize,
48    elem_i: u32,
49    sew: Vsew,
50) -> u64 {
51    let sew_bytes = usize::from(sew.bytes());
52    let elems_per_reg = VLENB / sew_bytes;
53    let reg_off = elem_i as usize / elems_per_reg;
54    let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
55    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
56    let reg = unsafe { vreg.get_unchecked(base_reg + reg_off) };
57    // SAFETY: `byte_off + sew_bytes <= VLENB` because `byte_off` is at most
58    // `(elems_per_reg - 1) * sew_bytes = VLENB - sew_bytes`
59    let src = unsafe { reg.get_unchecked(byte_off..byte_off + sew_bytes) };
60    let mut buf = [0u8; 8];
61    // SAFETY: `sew_bytes <= 8` for all `Vsew` variants
62    unsafe { buf.get_unchecked_mut(..sew_bytes) }.copy_from_slice(src);
63    u64::from_le_bytes(buf)
64}
65
66/// Write a SEW-wide element (low `sew_bytes` of `value`) into register group
67/// `[base_reg, base_reg + group_regs)` at element index `elem_i`.
68///
69/// # Safety
70/// `base_reg + elem_i / (VLENB / sew_bytes) < 32` must hold.
71#[inline(always)]
72pub(in super::super) unsafe fn write_element_u64<const VLENB: usize>(
73    vreg: &mut [[u8; VLENB]; 32],
74    base_reg: u8,
75    elem_i: u32,
76    sew: Vsew,
77    value: u64,
78) {
79    let sew_bytes = usize::from(sew.bytes());
80    let elems_per_reg = VLENB / sew_bytes;
81    let reg_off = elem_i as usize / elems_per_reg;
82    let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
83    let buf = value.to_le_bytes();
84    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
85    let reg = unsafe { vreg.get_unchecked_mut(usize::from(base_reg) + reg_off) };
86    // SAFETY: `byte_off + sew_bytes <= VLENB` - same argument as `read_element_u64`.
87    // `sew_bytes <= 8` for all `Vsew` variants.
88    let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + sew_bytes) };
89    // SAFETY: `sew_bytes <= 8` for all `Vsew` variants
90    dst.copy_from_slice(unsafe { buf.get_unchecked(..sew_bytes) });
91}
92
93/// Write one mask bit (the comparison result for element `elem_i`) into register `vd`.
94///
95/// Bits are stored LSB-first: element `i` lives at byte `i / 8`, bit `i % 8`.
96/// Only the target bit is modified; all other bits are undisturbed (tail-undisturbed semantics
97/// required for mask destinations per spec §5.3).
98///
99/// # Safety
100/// `elem_i / 8 < VLENB` must hold, i.e. `elem_i < VLEN`. This is guaranteed when
101/// `elem_i < vl <= VLMAX <= VLEN`.
102#[inline(always)]
103pub(in super::super) unsafe fn write_mask_bit<const VLENB: usize>(
104    vreg: &mut [[u8; VLENB]; 32],
105    vd: VReg,
106    elem_i: u32,
107    result: bool,
108) {
109    let byte_idx = (elem_i / u8::BITS) as usize;
110    let bit_idx = elem_i % u8::BITS;
111    // SAFETY: `byte_idx < VLENB` by the caller's precondition
112    let byte = unsafe {
113        vreg.get_unchecked_mut(usize::from(vd.bits()))
114            .get_unchecked_mut(byte_idx)
115    };
116    if result {
117        *byte |= 1 << bit_idx;
118    } else {
119        *byte &= !(1 << bit_idx);
120    }
121}
122
/// Source of the second operand for an element-wise vector operation: either another
/// vector register group or a single scalar broadcast to every element
#[derive(Debug)]
#[doc(hidden)]
pub enum OpSrc {
    /// Vector-vector: source register index
    Vreg(u8),
    /// Vector-scalar: scalar value (sign- or zero-extended to u64)
    Scalar(u64),
}
132
133/// Execute a single-width element-wise arithmetic operation over `vstart..vl`.
134///
135/// `op` receives `(vs2_elem: u64, src_elem: u64, sew: Vsew)` and returns the `u64` result (only the
136/// low `sew.bits()` are written back).
137///
138/// # Safety
139/// - `vd.bits() % group_regs == 0` and `vd.bits() + group_regs <= 32` (verified by caller)
140/// - `src` register (when `OpSrc::Vreg`) satisfies the same alignment (verified by caller)
141/// - `vl <= group_regs * VLENB / sew_bytes` (all `vl` elements fit within the register group)
142/// - When `vm=false`: `vd.bits() != 0` (vd does not overlap v0)
143#[inline(always)]
144#[expect(clippy::too_many_arguments, reason = "Internal API")]
145#[doc(hidden)]
146pub unsafe fn execute_arith_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
147    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
148    vd: VReg,
149    vs2: VReg,
150    src: OpSrc,
151    vm: bool,
152    vl: u32,
153    vstart: u32,
154    sew: Vsew,
155    op: F,
156) where
157    Reg: Register,
158    [(); Reg::N]:,
159    ExtState: VectorRegistersExt<Reg, CustomError>,
160    [(); ExtState::ELEN as usize]:,
161    [(); ExtState::VLEN as usize]:,
162    [(); ExtState::VLENB as usize]:,
163    Memory: VirtualMemory,
164    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
165    CustomError: fmt::Debug,
166    F: Fn(u64, u64, Vsew) -> u64,
167{
168    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
169    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
170
171    let vd_base = vd.bits();
172    let vs2_base = vs2.bits();
173
174    for i in vstart..vl {
175        if !mask_bit(&mask_buf, i) {
176            continue;
177        }
178
179        // SAFETY: `vs2_base % group_regs == 0` and `i < vl <= group_regs * elems_per_reg`,
180        // so `vs2_base + i / elems_per_reg < vs2_base + group_regs <= 32`
181        let a =
182            unsafe { read_element_u64(state.ext_state.read_vreg(), usize::from(vs2_base), i, sew) };
183
184        let b = match &src {
185            OpSrc::Vreg(vs1_base) => {
186                // SAFETY: same argument as vs2
187                unsafe {
188                    read_element_u64(state.ext_state.read_vreg(), usize::from(*vs1_base), i, sew)
189                }
190            }
191            OpSrc::Scalar(val) => *val,
192        };
193
194        let result = op(a, b, sew);
195
196        // SAFETY: `vd_base % group_regs == 0` and `i < vl <= group_regs * elems_per_reg`,
197        // so `vd_base + i / elems_per_reg < vd_base + group_regs <= 32`
198        unsafe {
199            write_element_u64(state.ext_state.write_vreg(), vd_base, i, sew, result);
200        }
201    }
202
203    state.ext_state.mark_vs_dirty();
204    state.ext_state.reset_vstart();
205}
206
207/// Execute a single-width element-wise integer compare over `vstart..vl`, writing one result
208/// bit per element into the mask register `vd`.
209///
210/// `op` receives `(vs2_elem: u64, src_elem: u64, sew: Vsew) -> bool`.
211///
212/// Mask destination tail bits (indices `>= vl`) are always left undisturbed per spec §5.3,
213/// regardless of `vta`. Only bits in `vstart..vl` are written.
214///
215/// # Safety
216/// - `vs2.bits() % group_regs == 0` and `vs2.bits() + group_regs <= 32` (verified by caller)
217/// - `src` register (when `OpSrc::Vreg`) satisfies the same alignment (verified by caller)
218/// - `vl <= group_regs * VLENB / sew_bytes`
219/// - `vl <= VLEN` (so every element index fits within the mask register)
220#[inline(always)]
221#[expect(clippy::too_many_arguments, reason = "Internal API")]
222#[doc(hidden)]
223pub unsafe fn execute_compare_op<Reg, ExtState, Memory, PC, IH, CustomError, F>(
224    state: &mut InterpreterState<Reg, ExtState, Memory, PC, IH, CustomError>,
225    vd: VReg,
226    vs2: VReg,
227    src: OpSrc,
228    vm: bool,
229    vl: u32,
230    vstart: u32,
231    sew: Vsew,
232    op: F,
233) where
234    Reg: Register,
235    [(); Reg::N]:,
236    ExtState: VectorRegistersExt<Reg, CustomError>,
237    [(); ExtState::ELEN as usize]:,
238    [(); ExtState::VLEN as usize]:,
239    [(); ExtState::VLENB as usize]:,
240    Memory: VirtualMemory,
241    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
242    CustomError: fmt::Debug,
243    F: Fn(u64, u64, Vsew) -> bool,
244{
245    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
246    let mask_buf = unsafe { snapshot_mask(state.ext_state.read_vreg(), vm, vl) };
247
248    let vs2_base = vs2.bits();
249
250    for i in vstart..vl {
251        // When masked, inactive elements in the destination mask register are left undisturbed
252        // (spec §12.8: "mask register results follow mask-undisturbed policy")
253        if !mask_bit(&mask_buf, i) {
254            continue;
255        }
256
257        // SAFETY: same argument as in `execute_arith_op`
258        let a =
259            unsafe { read_element_u64(state.ext_state.read_vreg(), usize::from(vs2_base), i, sew) };
260
261        let b = match &src {
262            OpSrc::Vreg(vs1_base) => {
263                // SAFETY: same argument as vs2
264                unsafe {
265                    read_element_u64(state.ext_state.read_vreg(), usize::from(*vs1_base), i, sew)
266                }
267            }
268            OpSrc::Scalar(val) => *val,
269        };
270
271        let result = op(a, b, sew);
272
273        // SAFETY: `i < vl <= VLMAX <= VLEN`, so `i / 8 < VLEN / 8 = VLENB`
274        unsafe {
275            write_mask_bit(state.ext_state.write_vreg(), vd, i, result);
276        }
277    }
278
279    state.ext_state.mark_vs_dirty();
280    state.ext_state.reset_vstart();
281}
282
283/// Sign-extend the low `sew.bits()` of `val` to a full `i64`
284#[inline(always)]
285#[doc(hidden)]
286pub fn sign_extend(val: u64, sew: Vsew) -> i64 {
287    let shift = u64::BITS - u32::from(sew.bits());
288    (val.cast_signed() << shift) >> shift
289}
290
291/// Mask off the upper bits of a `u64` to leave only the low `sew.bits()`.
292///
293/// Used for unsigned arithmetic and comparisons where only the SEW-wide portion is significant. For
294/// SEW = 64 this is a no-op (all bits are significant).
295#[inline(always)]
296#[doc(hidden)]
297pub fn sew_mask(sew: Vsew) -> u64 {
298    if u32::from(sew.bits()) == u64::BITS {
299        u64::MAX
300    } else {
301        (1u64 << sew.bits()) - 1
302    }
303}