Skip to main content

ab_riscv_interpreter/v/zvexx/arith/
zvexx_arith_helpers.rs

1//! Opaque helpers for ZveXx extension
2
3use crate::v::vector_registers::{VectorRegisterFile, VectorRegistersExt};
4use crate::v::zvexx::load::zvexx_load_helpers::{mask_bit, snapshot_mask};
5use crate::v::zvexx::zvexx_helpers::INSTRUCTION_SIZE;
6use crate::{ExecutionError, ProgramCounter};
7use ab_riscv_primitives::prelude::*;
8use core::fmt;
9
10/// Check that `vreg` (`vd`/`vs`) is aligned to `group_regs` and fits within `[0, 32)`
11#[inline(always)]
12#[doc(hidden)]
13pub fn check_vreg_group_alignment<Reg, Memory, PC, CustomError>(
14    program_counter: &PC,
15    vreg: VReg,
16    group_regs: u8,
17) -> Result<(), ExecutionError<Reg::Type, CustomError>>
18where
19    Reg: Register,
20    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
21{
22    let vreg_idx = vreg.to_bits();
23    if !vreg_idx.is_multiple_of(group_regs) || vreg_idx + group_regs > 32 {
24        return Err(ExecutionError::IllegalInstruction {
25            address: program_counter.old_pc(INSTRUCTION_SIZE),
26        });
27    }
28    Ok(())
29}
30
31/// Check mask-destination / source overlap constraint for compare instructions.
32///
33/// Per RVV §11.8: a mask destination register may overlap a source register group only when
34/// the source group occupies a single register (LMUL ≤ 1, i.e. `group_regs == 1`). Otherwise
35/// the encoding is reserved and raises an illegal instruction.
36#[inline(always)]
37#[doc(hidden)]
38pub fn check_mask_dest_no_overlap<Reg, Memory, PC, CustomError>(
39    program_counter: &PC,
40    vd: VReg,
41    src_base: VReg,
42    group_regs: u8,
43) -> Result<(), ExecutionError<Reg::Type, CustomError>>
44where
45    Reg: Register,
46    PC: ProgramCounter<Reg::Type, Memory, CustomError>,
47{
48    if group_regs > 1 {
49        let vd_idx = vd.to_bits();
50        let src = src_base.to_bits();
51        if vd_idx >= src && vd_idx < src + group_regs {
52            return Err(ExecutionError::IllegalInstruction {
53                address: program_counter.old_pc(INSTRUCTION_SIZE),
54            });
55        }
56    }
57    Ok(())
58}
59
60/// Read a SEW-wide element from register group `[base_reg, base_reg + group_regs)` as `u64`.
61///
62/// Element `elem_i` occupies bytes at:
63///   - register `base_reg + elem_i / elems_per_reg`
64///   - byte offset `(elem_i % elems_per_reg) * sew_bytes`
65///
66/// The value is zero-extended to `u64`.
67///
68/// # Safety
69/// `base_reg + elem_i / (VLENB / sew_bytes) < 32` must hold.
70#[inline(always)]
71pub(in super::super) unsafe fn read_element_u64<const VLENB: usize>(
72    vregs: &VectorRegisterFile<VLENB>,
73    base_reg: VReg,
74    elem_i: u32,
75    sew: Vsew,
76) -> u64 {
77    let sew_bytes = usize::from(sew.bytes_width());
78    let elems_per_reg = VLENB / sew_bytes;
79    let reg_off = elem_i as usize / elems_per_reg;
80    let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
81    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
82    let reg = vregs
83        .get(unsafe { VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked() });
84    // SAFETY: `byte_off + sew_bytes <= VLENB` because `byte_off` is at most
85    // `(elems_per_reg - 1) * sew_bytes = VLENB - sew_bytes`
86    let src = unsafe { reg.get_unchecked(byte_off..byte_off + sew_bytes) };
87    let mut buf = [0u8; 8];
88    // SAFETY: `sew_bytes <= 8` for all `Vsew` variants
89    unsafe { buf.get_unchecked_mut(..sew_bytes) }.copy_from_slice(src);
90    u64::from_le_bytes(buf)
91}
92
93/// Write a SEW-wide element (low `sew_bytes` of `value`) into register group
94/// `[base_reg, base_reg + group_regs)` at element index `elem_i`.
95///
96/// # Safety
97/// `base_reg + elem_i / (VLENB / sew_bytes) < 32` must hold.
98#[inline(always)]
99pub(in super::super) unsafe fn write_element_u64<const VLENB: usize>(
100    vregs: &mut VectorRegisterFile<VLENB>,
101    base_reg: VReg,
102    elem_i: u32,
103    sew: Vsew,
104    value: u64,
105) {
106    let sew_bytes = usize::from(sew.bytes_width());
107    let elems_per_reg = VLENB / sew_bytes;
108    let reg_off = elem_i as usize / elems_per_reg;
109    let byte_off = (elem_i as usize % elems_per_reg) * sew_bytes;
110    let buf = value.to_le_bytes();
111    // SAFETY: `base_reg + reg_off < 32` by caller's precondition
112    let reg = vregs
113        .get_mut(unsafe { VReg::from_bits(base_reg.to_bits() + reg_off as u8).unwrap_unchecked() });
114    // SAFETY: `byte_off + sew_bytes <= VLENB` - same argument as `read_element_u64`.
115    // `sew_bytes <= 8` for all `Vsew` variants.
116    let dst = unsafe { reg.get_unchecked_mut(byte_off..byte_off + sew_bytes) };
117    // SAFETY: `sew_bytes <= 8` for all `Vsew` variants
118    dst.copy_from_slice(unsafe { buf.get_unchecked(..sew_bytes) });
119}
120
121/// Write one mask bit (the comparison result for element `elem_i`) into register `vd`.
122///
123/// Bits are stored LSB-first: element `i` lives at byte `i / 8`, bit `i % 8`.
124/// Only the target bit is modified; all other bits are undisturbed (tail-undisturbed semantics
125/// required for mask destinations per spec §5.3).
126///
127/// # Safety
128/// `elem_i / 8 < VLENB` must hold, i.e. `elem_i < VLEN`. This is guaranteed when
129/// `elem_i < vl <= VLMAX <= VLEN`.
130#[inline(always)]
131pub(in super::super) unsafe fn write_mask_bit<const VLENB: usize>(
132    vregs: &mut VectorRegisterFile<VLENB>,
133    vd: VReg,
134    elem_i: u32,
135    result: bool,
136) {
137    let byte_idx = (elem_i / u8::BITS) as usize;
138    let bit_idx = elem_i % u8::BITS;
139    // SAFETY: `byte_idx < VLENB` by the caller's precondition
140    let byte = unsafe { vregs.get_mut(vd).get_unchecked_mut(byte_idx) };
141    if result {
142        *byte |= 1 << bit_idx;
143    } else {
144        *byte &= !(1 << bit_idx);
145    }
146}
147
148/// Operand source
149#[derive(Debug)]
150#[doc(hidden)]
151pub enum OpSrc {
152    /// Vector-vector: source register index
153    Vreg(VReg),
154    /// Vector-scalar: scalar value (sign- or zero-extended to u64)
155    Scalar(u64),
156}
157
158/// Execute a single-width element-wise arithmetic operation over `vstart..vl`.
159///
160/// `op` receives `(vs2_elem: u64, src_elem: u64, sew: Vsew)` and returns the `u64` result (only the
161/// low `sew.bits_width()` are written back).
162///
163/// # Safety
164/// - `vd.to_bits() % group_regs == 0` and `vd.to_bits() + group_regs <= 32` (verified by caller)
165/// - `src` register (when `OpSrc::Vreg`) satisfies the same alignment (verified by caller)
166/// - `vl <= group_regs * VLENB / sew_bytes` (all `vl` elements fit within the register group)
167/// - When `vm=false`: `vd.to_bits() != 0` (vd does not overlap v0)
168#[inline(always)]
169#[doc(hidden)]
170pub unsafe fn execute_arith_op<Reg, ExtState, CustomError, F>(
171    ext_state: &mut ExtState,
172    vd: VReg,
173    vs2: VReg,
174    src: OpSrc,
175    vm: bool,
176    sew: Vsew,
177    op: F,
178) where
179    Reg: Register,
180    ExtState: VectorRegistersExt<Reg, CustomError>,
181    [(); ExtState::ELEN as usize]:,
182    [(); ExtState::VLEN as usize]:,
183    [(); ExtState::VLENB as usize]:,
184    CustomError: fmt::Debug,
185    F: Fn(u64, u64, Vsew) -> u64,
186{
187    let vl = ext_state.vl();
188    let vstart = ext_state.vstart();
189    // SAFETY: `vl <= VLMAX <= VLEN`, so `vl.div_ceil(8) <= VLENB`
190    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
191
192    for i in u32::from(vstart)..vl {
193        if !mask_bit(&mask_buf, i) {
194            continue;
195        }
196
197        // SAFETY: `vs2 % group_regs == 0` and `i < vl <= group_regs * elems_per_reg`, so
198        // `vs2 + i / elems_per_reg < vs2 + group_regs <= 32`
199        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
200
201        let b = match src {
202            OpSrc::Vreg(vs1_base) => {
203                // SAFETY: same argument as vs2
204                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
205            }
206            OpSrc::Scalar(val) => val,
207        };
208
209        let result = op(a, b, sew);
210
211        // SAFETY: `vd % group_regs == 0` and `i < vl <= group_regs * elems_per_reg`, so
212        // `vd + i / elems_per_reg < vd + group_regs <= 32`
213        unsafe {
214            write_element_u64(ext_state.write_vregs(), vd, i, sew, result);
215        }
216    }
217
218    ext_state.mark_vs_dirty();
219    ext_state.reset_vstart();
220}
221
222/// Execute a single-width element-wise integer compare over `vstart..vl`, writing one result
223/// bit per element into the mask register `vd`.
224///
225/// `op` receives `(vs2_elem: u64, src_elem: u64, sew: Vsew) -> bool`.
226///
227/// Mask destination tail bits (indices `>= vl`) are always left undisturbed per spec §5.3,
228/// regardless of `vta`. Only bits in `vstart..vl` are written.
229///
230/// # Safety
231/// - `vs2.to_bits() % group_regs == 0` and `vs2.to_bits() + group_regs <= 32` (verified by caller)
232/// - `src` register (when `OpSrc::Vreg`) satisfies the same alignment (verified by caller)
233/// - `vl <= group_regs * VLENB / sew_bytes`
234/// - `vl <= VLEN` (so every element index fits within the mask register)
235#[inline(always)]
236#[doc(hidden)]
237pub unsafe fn execute_compare_op<Reg, ExtState, CustomError, F>(
238    ext_state: &mut ExtState,
239    vd: VReg,
240    vs2: VReg,
241    src: OpSrc,
242    vm: bool,
243    sew: Vsew,
244    op: F,
245) where
246    Reg: Register,
247    ExtState: VectorRegistersExt<Reg, CustomError>,
248    [(); ExtState::ELEN as usize]:,
249    [(); ExtState::VLEN as usize]:,
250    [(); ExtState::VLENB as usize]:,
251    CustomError: fmt::Debug,
252    F: Fn(u64, u64, Vsew) -> bool,
253{
254    let vl = ext_state.vl();
255    let vstart = ext_state.vstart();
256    // SAFETY: `vl <= VLEN`, so `vl.div_ceil(8) <= VLENB`.
257    let mask_buf = unsafe { snapshot_mask(ext_state.read_vregs(), vm, vl) };
258
259    for i in u32::from(vstart)..vl {
260        // When masked, inactive elements in the destination mask register are left undisturbed
261        // (spec §12.8: "mask register results follow mask-undisturbed policy")
262        if !mask_bit(&mask_buf, i) {
263            continue;
264        }
265
266        // SAFETY: same argument as in `execute_arith_op`
267        let a = unsafe { read_element_u64(ext_state.read_vregs(), vs2, i, sew) };
268
269        let b = match src {
270            OpSrc::Vreg(vs1_base) => {
271                // SAFETY: same argument as vs2
272                unsafe { read_element_u64(ext_state.read_vregs(), vs1_base, i, sew) }
273            }
274            OpSrc::Scalar(val) => val,
275        };
276
277        let result = op(a, b, sew);
278
279        // SAFETY: `i < vl <= VLMAX <= VLEN`, so `i / 8 < VLEN / 8 = VLENB`
280        unsafe {
281            write_mask_bit(ext_state.write_vregs(), vd, i, result);
282        }
283    }
284
285    ext_state.mark_vs_dirty();
286    ext_state.reset_vstart();
287}
288
289/// Sign-extend the low `sew.bits_width()` of `val` to a full `i64`
290#[inline(always)]
291#[doc(hidden)]
292pub fn sign_extend(val: u64, sew: Vsew) -> i64 {
293    let shift = u64::BITS - u32::from(sew.bits_width());
294    (val.cast_signed() << shift) >> shift
295}
296
297/// Mask off the upper bits of a `u64` to leave only the low `sew.bits_width()`.
298///
299/// Used for unsigned arithmetic and comparisons where only the SEW-wide portion is significant. For
300/// SEW = 64 this is a no-op (all bits are significant).
301#[inline(always)]
302#[doc(hidden)]
303pub fn sew_mask(sew: Vsew) -> u64 {
304    if u32::from(sew.bits_width()) == u64::BITS {
305        u64::MAX
306    } else {
307        (1u64 << sew.bits_width()) - 1
308    }
309}